Source code for ott.geometry.geometry

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union

if TYPE_CHECKING:
  from ott.geometry import low_rank

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.tree_util as jtu

from ott import utils
from ott.geometry import epsilon_scheduler as eps_scheduler
from ott.math import utils as mu

__all__ = ["Geometry"]


[docs] @jtu.register_pytree_node_class class Geometry: r"""Base class to define ground costs/kernels used in optimal transport. Optimal transport problems are intrinsically geometric: they compute an optimal way to transport mass from one configuration onto another. To define what is meant by optimality of transport requires defining a :term:`ground cost`, which quantifies how costly it is to move mass from one among several source locations, towards one out of multiple target locations. These source and target locations can be described as points in vectors spaces, grids, or more generally described through a (dissimilarity) cost matrix, or almost equivalently, a (similarity) kernel matrix. This class describes such a geometry and several useful methods to exploit it. Args: cost_matrix: Cost matrix of shape ``[n, m]``. kernel_matrix: Kernel matrix of shape ``[n, m]``. epsilon: Regularization parameter or a scheduler: - ``epsilon = None`` and ``relative_epsilon = None``, use :math:`0.05 * \text{stddev(cost_matrix)}`. - if ``epsilon`` is a :class:`float` and ``relative_epsilon = None``, it directly corresponds to the regularization strength. - otherwise, ``epsilon`` multiplies the :attr:`mean_cost_matrix` or :attr:`std_cost_matrix`, depending on the value of ``relative_epsilon``. If ``epsilon = None``, the value of :obj:`DEFAULT_EPSILON_SCALE = 0.05 <ott.geometry.epsilon_scheduler.DEFAULT_EPSILON_SCALE>`. will be used. relative_epsilon: Whether ``epsilon`` refers to a fraction of the :attr:`mean_cost_matrix` or :attr:`std_cost_matrix`. scale_cost: option to rescale the cost matrix. Implemented scalings are 'median', 'mean', 'std' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. Note: When defining a :class:`~ott.geometry.geometry.Geometry` through a ``cost_matrix``, it is important to select an ``epsilon`` regularization parameter that is meaningful. That parameter can be provided by the user, or assigned a default value through a simple rule, using for instance the :attr:`mean_cost_matrix` or the :attr:`std_cost_matrix`. """ # noqa: E501 def __init__( self, cost_matrix: Optional[jnp.ndarray] = None, kernel_matrix: Optional[jnp.ndarray] = None, epsilon: Optional[Union[float, eps_scheduler.Epsilon]] = None, relative_epsilon: Optional[Literal["mean", "std"]] = None, scale_cost: Union[float, Literal["mean", "max_cost", "median", "std"]] = 1.0, ): self._cost_matrix = cost_matrix self._kernel_matrix = kernel_matrix self._epsilon_init = epsilon self._relative_epsilon = relative_epsilon self._scale_cost = scale_cost @property def cost_rank(self) -> Optional[int]: """Output rank of cost matrix, if any was provided.""" @property def cost_matrix(self) -> jnp.ndarray: """Cost matrix, recomputed from kernel if only kernel was specified.""" if self._cost_matrix is None: # If no epsilon was passed on to the geometry, then assume it is one by # default. eps = jnp.finfo(self._kernel_matrix.dtype).tiny cost = -jnp.log(self._kernel_matrix + eps) cost *= self.inv_scale_cost return cost if self._epsilon_init is None else self.epsilon * cost return self._cost_matrix * self.inv_scale_cost @property def median_cost_matrix(self) -> float: """Median of the :attr:`cost_matrix`.""" return jnp.median(self.cost_matrix) @property def mean_cost_matrix(self) -> float: """Mean of the :attr:`cost_matrix`.""" n, m = self.shape tmp = self.apply_cost(jnp.full((n,), fill_value=1.0 / n)) return jnp.sum((1.0 / m) * tmp) @property def std_cost_matrix(self) -> float: r"""Standard deviation of all values stored in :attr:`cost_matrix`. Uses the :meth:`apply_square_cost` to remain applicable to low-rank matrices, through the formula: .. math:: \sigma^2=\frac{1}{nm}\left(\sum_{ij} C_{ij}^2 - (\sum_{ij}C_ij)^2\right). to output :math:`\sigma`. """ n, m = self.shape tmp = self.apply_square_cost(jnp.full((n,), fill_value=1.0 / n)) tmp = jnp.sum((1.0 / m) * tmp) - (self.mean_cost_matrix ** 2) return jnp.sqrt(jax.nn.relu(tmp)) @property def kernel_matrix(self) -> jnp.ndarray: """Kernel matrix. Either provided by user or recomputed from :attr:`cost_matrix`. """ if self._kernel_matrix is None: return jnp.exp(-self._cost_matrix * self.inv_scale_cost / self.epsilon) return self._kernel_matrix ** self.inv_scale_cost @property def epsilon_scheduler(self) -> eps_scheduler.Epsilon: """Epsilon scheduler.""" if isinstance(self._epsilon_init, eps_scheduler.Epsilon): return self._epsilon_init # no relative epsilon if self._relative_epsilon is None: if self._epsilon_init is not None: return eps_scheduler.Epsilon(self._epsilon_init) multiplier = eps_scheduler.DEFAULT_EPSILON_SCALE scale = jax.lax.stop_gradient(self.std_cost_matrix) return eps_scheduler.Epsilon(target=multiplier * scale) if self._relative_epsilon == "std": scale = jax.lax.stop_gradient(self.std_cost_matrix) elif self._relative_epsilon == "mean": scale = jax.lax.stop_gradient(self.mean_cost_matrix) else: raise ValueError(f"Invalid relative epsilon: {self._relative_epsilon}.") multiplier = ( eps_scheduler.DEFAULT_EPSILON_SCALE if self._epsilon_init is None else self._epsilon_init ) return eps_scheduler.Epsilon(target=multiplier * scale) @property def epsilon(self) -> float: """Epsilon regularization value.""" return self.epsilon_scheduler.target @property def shape(self) -> Tuple[int, int]: """Shape of the geometry.""" mat = ( self._kernel_matrix if self._cost_matrix is None else self._cost_matrix ) if mat is not None: return mat.shape return 0, 0 @property def can_LRC(self) -> bool: """Check quickly if casting geometry as LRC makes sense. This check is only carried out using basic considerations from the geometry, not using a rigorous check involving, e.g., SVD. """ return False @property def is_squared_euclidean(self) -> bool: """Whether cost is computed by taking squared Euclidean distance.""" return False @property def is_online(self) -> bool: """Whether geometry cost/kernel should be recomputed on the fly.""" return False @property def is_symmetric(self) -> bool: """Whether geometry cost/kernel is a symmetric matrix.""" n, m = self.shape mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix return (n == m) and jnp.all(mat == mat.T) @property def inv_scale_cost(self) -> jnp.ndarray: """Compute and return inverse of scaling factor for cost matrix.""" if self._scale_cost == "max_cost": return 1.0 / jnp.max(self._cost_matrix) if self._scale_cost == "mean": return 1.0 / jnp.mean(self._cost_matrix) if self._scale_cost == "median": return 1.0 / jnp.median(self._cost_matrix) if utils.is_scalar(self._scale_cost): return 1.0 / self._scale_cost raise ValueError(f"Scaling {self._scale_cost} not implemented.")
[docs] def set_scale_cost(self, scale_cost: Union[float, str]) -> "Geometry": """Modify how to rescale of the :attr:`cost_matrix`.""" # case when `geom` doesn't have `scale_cost` or doesn't need to be modified # `False` retains the original scale if scale_cost == self._scale_cost: return self children, aux_data = self.tree_flatten() aux_data["scale_cost"] = scale_cost return type(self).tree_unflatten(aux_data, children)
[docs] def copy_epsilon(self, other: "Geometry") -> "Geometry": """Copy the epsilon parameters from another geometry.""" children, aux_data = self.tree_flatten() new_geom = type(self).tree_unflatten(aux_data, children) new_geom._epsilon_init = other.epsilon_scheduler new_geom._relative_epsilon = other._relative_epsilon # has no effect return new_geom
# The functions below are at the core of Sinkhorn iterations, they # are implemented here in their default form, either in lse (using directly # cost matrices in stabilized form) or kernel mode (using kernel matrices).
[docs] def apply_lse_kernel( self, f: jnp.ndarray, g: jnp.ndarray, eps: float, vec: jnp.ndarray = None, axis: int = 0 ) -> Tuple[jnp.ndarray, jnp.ndarray]: r"""Apply :attr:`kernel_matrix` in log domain. This function applies the ground geometry's kernel in log domain, using a stabilized formulation. At a high level, this iteration performs either: - output = eps * log (K (exp(g / eps) * vec)) (1) - output = eps * log (K'(exp(f / eps) * vec)) (2) K is implicitly exp(-cost_matrix/eps). To carry this out in a stabilized way, we take advantage of the fact that the entries of the matrix ``f[:,*] + g[*,:] - C`` are all negative, and therefore their exponential never overflows, to add (and subtract after) f and g in iterations 1 & 2 respectively. Args: f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix eps: float, regularization strength vec: jnp.ndarray [num_a or num_b,] , when not None, this has the effect of doing log-Kernel computations with an addition elementwise multiplication of exp(g / eps) by a vector. This is carried out by adding weights to the log-sum-exp function, and needs to handle signs separately. axis: summing over axis 0 when doing (2), or over axis 1 when doing (1) Returns: A jnp.ndarray corresponding to output above, depending on axis. """ w_res, w_sgn = self._softmax(f, g, eps, vec, axis) remove = f if axis == 1 else g return w_res - jnp.where(jnp.isfinite(remove), remove, 0), w_sgn
[docs] def apply_kernel( self, vec: jnp.ndarray, eps: Optional[float] = None, axis: int = 0, ) -> jnp.ndarray: """Apply :attr:`kernel_matrix` on positive scaling vector. Args: vec: jnp.ndarray [num_a or num_b] , scaling of size num_rows or num_cols of kernel_matrix eps: passed for consistency, not used yet. axis: standard kernel product if axis is 1, transpose if 0. Returns: a jnp.ndarray corresponding to output above, depending on axis. """ if eps is None: kernel = self.kernel_matrix else: kernel = self.kernel_matrix ** (self.epsilon / eps) kernel = kernel if axis == 1 else kernel.T return jnp.dot(kernel, vec)
[docs] def marginal_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray, axis: int = 0, ) -> jnp.ndarray: """Output marginal of transportation matrix from potentials. This applies first lse kernel in the standard way, removes the correction used to stabilize computations, and lifts this with an exp to recover either of the marginals corresponding to the transport map induced by potentials. Args: f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix axis: axis along which to integrate, returns marginal on other axis. Returns: a vector of marginals of the transport matrix. """ h = (f if axis == 1 else g) z = self.apply_lse_kernel(f, g, self.epsilon, axis=axis)[0] return jnp.exp((z + h) / self.epsilon)
[docs] def marginal_from_scalings( self, u: jnp.ndarray, v: jnp.ndarray, axis: int = 0, ) -> jnp.ndarray: """Output marginal of transportation matrix from scalings.""" u, v = (v, u) if axis == 0 else (u, v) return u * self.apply_kernel(v, eps=self.epsilon, axis=axis)
[docs] def transport_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray ) -> jnp.ndarray: """Output transport matrix from potentials.""" return jnp.exp(self._center(f, g) / self.epsilon)
[docs] def transport_from_scalings( self, u: jnp.ndarray, v: jnp.ndarray ) -> jnp.ndarray: """Output transport matrix from pair of scalings.""" return self.kernel_matrix * u[:, jnp.newaxis] * v[jnp.newaxis, :]
# Functions that are not supposed to be changed by inherited classes. # These are the point of entry for Sinkhorn's algorithm to use a geometry.
[docs] def update_potential( self, f: jnp.ndarray, g: jnp.ndarray, log_marginal: jnp.ndarray, iteration: Optional[int] = None, axis: int = 0, ) -> jnp.ndarray: """Carry out one Sinkhorn update for potentials, i.e. in log space. Args: f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix log_marginal: targeted marginal iteration: used to compute epsilon from schedule, if provided. axis: axis along which the update should be carried out. Returns: new potential value, g if axis=0, f if axis is 1. """ eps = self.epsilon_scheduler(iteration) app_lse = self.apply_lse_kernel(f, g, eps, axis=axis)[0] return eps * log_marginal - jnp.where(jnp.isfinite(app_lse), app_lse, 0)
[docs] def update_scaling( self, scaling: jnp.ndarray, marginal: jnp.ndarray, iteration: Optional[int] = None, axis: int = 0, ) -> jnp.ndarray: """Carry out one Sinkhorn update for scalings, using kernel directly. Args: scaling: jnp.ndarray of num_a or num_b positive values. marginal: targeted marginal iteration: used to compute epsilon from schedule, if provided. axis: axis along which the update should be carried out. Returns: new scaling vector, of size num_b if axis=0, num_a if axis is 1. """ eps = self.epsilon_scheduler(iteration) app_kernel = self.apply_kernel(scaling, eps, axis=axis) return marginal / jnp.where(app_kernel > 0, app_kernel, 1.0)
# Helper functions def _center(self, f: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray: return f[:, jnp.newaxis] + g[jnp.newaxis, :] - self.cost_matrix def _softmax( self, f: jnp.ndarray, g: jnp.ndarray, eps: float, vec: Optional[jnp.ndarray], axis: int ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Apply softmax row or column wise, weighted by vec.""" if vec is not None: if axis == 0: vec = vec.reshape((-1, 1)) lse_output = mu.logsumexp( self._center(f, g) / eps, b=vec, axis=axis, return_sign=True ) return eps * lse_output[0], lse_output[1] lse_output = mu.logsumexp( self._center(f, g) / eps, axis=axis, return_sign=False ) return eps * lse_output, jnp.array([1.0]) @functools.partial(jax.vmap, in_axes=[None, None, None, 0, None]) def _apply_transport_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray, vec: jnp.ndarray, axis: int ) -> jnp.ndarray: """Apply lse_kernel to arbitrary vector while keeping track of signs.""" lse_res, lse_sgn = self.apply_lse_kernel( f, g, self.epsilon, vec=vec, axis=axis ) lse_res += f if axis == 1 else g return lse_sgn * jnp.exp(lse_res / self.epsilon) # wrapper to allow default option for axis.
[docs] def apply_transport_from_potentials( self, f: jnp.ndarray, g: jnp.ndarray, vec: jnp.ndarray, axis: int = 0 ) -> jnp.ndarray: """Apply transport matrix computed from potentials to a (batched) vec. This approach does not instantiate the transport matrix itself, but uses instead potentials to apply the transport using apply_lse_kernel, therefore guaranteeing stability and lower memory footprint. Computations are done in log space, and take advantage of the (b=..., return_sign=True) optional parameters of logsumexp. Args: f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied by transport matrix corresponding to potentials f, g, and geom. axis: axis to differentiate left (0) or right (1) multiply. Returns: ndarray of the size of vec. """ if vec.ndim == 1: return self._apply_transport_from_potentials( f, g, vec[jnp.newaxis, :], axis )[0, :] return self._apply_transport_from_potentials(f, g, vec, axis)
@functools.partial(jax.vmap, in_axes=[None, None, None, 0, None]) def _apply_transport_from_scalings( self, u: jnp.ndarray, v: jnp.ndarray, vec: jnp.ndarray, axis: int ): u, v = (u, v * vec) if axis == 1 else (v, u * vec) return u * self.apply_kernel(v, eps=self.epsilon, axis=axis) # wrapper to allow default option for axis
[docs] def apply_transport_from_scalings( self, u: jnp.ndarray, v: jnp.ndarray, vec: jnp.ndarray, axis: int = 0 ) -> jnp.ndarray: """Apply transport matrix computed from scalings to a (batched) vec. This approach does not instantiate the transport matrix itself, but relies instead on the apply_kernel function. Args: u: jnp.ndarray [num_a,] , scaling of size num_rows of cost_matrix v: jnp.ndarray [num_b,] , scaling of size num_cols of cost_matrix vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied by transport matrix corresponding to scalings u, v, and geom. axis: axis to differentiate left (0) or right (1) multiply. Returns: ndarray of the size of vec. """ if vec.ndim == 1: return self._apply_transport_from_scalings( u, v, vec[jnp.newaxis, :], axis )[0, :] return self._apply_transport_from_scalings(u, v, vec, axis)
[docs] def potential_from_scaling(self, scaling: jnp.ndarray) -> jnp.ndarray: """Compute dual potential vector from scaling vector. Args: scaling: vector. Returns: a vector of the same size. """ return self.epsilon * jnp.log(scaling)
[docs] def scaling_from_potential(self, potential: jnp.ndarray) -> jnp.ndarray: """Compute scaling vector from dual potential. Args: potential: vector. Returns: a vector of the same size. """ finite = jnp.isfinite(potential) return jnp.where( finite, jnp.exp(jnp.where(finite, potential / self.epsilon, 0.0)), 0.0 )
[docs] def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply elementwise-square of cost matrix to array (vector or matrix). This function applies the ground geometry's cost matrix, to perform either output = C arr (if axis=1) output = C' arr (if axis=0) where C is [num_a, num_b], when the cost matrix itself is computed as a squared-Euclidean distance between vectors, and therefore admits an explicit low-rank factorization. Args: arr: array. axis: axis of the array on which the cost matrix should be applied. Returns: An array, [num_b, p] if axis=0 or [num_a, p] if axis=1. """ return self.apply_cost(arr, axis=axis, fn=lambda x: x ** 2)
[docs] def apply_cost( self, arr: jnp.ndarray, axis: int = 0, fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, is_linear: bool = False, ) -> jnp.ndarray: """Apply :attr:`cost_matrix` to array (vector or matrix). This function applies the ground geometry's cost matrix, to perform either output = C arr (if axis=1) output = C' arr (if axis=0) where C is [num_a, num_b] Args: arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied by the cost matrix. axis: standard cost matrix if axis=1, transpose if 0 fn: function to apply to cost matrix element-wise before the dot product is_linear: Whether ``fn`` is linear. Returns: An array, [num_b, p] if axis=0 or [num_a, p] if axis=1 """ if arr.ndim == 1: return self._apply_cost_to_vec(arr, axis=axis, fn=fn, is_linear=is_linear) app = functools.partial( self._apply_cost_to_vec, axis=axis, fn=fn, is_linear=is_linear ) return jax.vmap(app, in_axes=1, out_axes=1)(arr)
def _apply_cost_to_vec( self, vec: jnp.ndarray, axis: int = 0, fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, is_linear: bool = False, ) -> jnp.ndarray: """Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector. Args: vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector axis: axis on which the reduction is done. fn: function optionally applied to cost matrix element-wise, before the doc product is_linear: Whether ``fn`` is linear. Returns: A jnp.ndarray corresponding to cost x vector """ del is_linear matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix if fn is not None: matrix = fn(matrix) return jnp.dot(matrix, vec)
[docs] @classmethod def prepare_divergences( cls, *args: Any, static_b: bool = False, **kwargs: Any ) -> Tuple["Geometry", ...]: """Instantiate 2 (or 3) geometries to compute a Sinkhorn divergence.""" size = 2 if static_b else 3 nones = [None, None, None] cost_matrices = kwargs.pop("cost_matrix", args) kernel_matrices = kwargs.pop("kernel_matrix", nones) cost_matrices = cost_matrices if cost_matrices is not None else nones return tuple( cls(cost_matrix=arg1, kernel_matrix=arg2, **kwargs) for arg1, arg2, _ in zip(cost_matrices, kernel_matrices, range(size)) )
[docs] def to_LRCGeometry( self, rank: int = 0, tol: float = 1e-2, rng: Optional[jax.Array] = None, scale: float = 1.0 ) -> "low_rank.LRCGeometry": r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`. When `rank=min(n,m)` or `0` (by default), use :func:`jax.numpy.linalg.svd`. For other values, use the routine in sublinear time :cite:`indyk:19`. Uses the implementation of :cite:`scetbon:21`, algorithm 4. It holds that with probability *0.99*, :math:`||A - UV||_F^2 \leq || A - A_k ||_F^2 + tol \cdot ||A||_F^2`, where :math:`A` is ``n x m`` cost matrix, :math:`UV` the factorization computed in sublinear time and :math:`A_k` the best rank-k approximation. Args: rank: Target rank of the :attr:`cost_matrix`. tol: Tolerance of the error. The total number of sampled points is :math:`min(n, m,\frac{rank}{tol})`. rng: The PRNG key to use for initializing the model. scale: Value used to rescale the factors of the low-rank geometry. Useful when this geometry is used in the linear term of fused GW. Returns: Low-rank geometry. """ from ott.geometry import low_rank assert rank >= 0, f"Rank must be non-negative, got {rank}." n, m = self.shape if rank == 0 or rank >= min(n, m): # TODO(marcocuturi): add hermitian=self.is_symmetric, currently bugging. u, s, vh = jnp.linalg.svd( self.cost_matrix, full_matrices=False, compute_uv=True, ) cost_1 = u cost_2 = (s[:, None] * vh).T else: rng = utils.default_prng_key(rng) rng1, rng2, rng3, rng4, rng5 = jax.random.split(rng, 5) n_subset = min(int(rank / tol), n, m) i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n) j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m) ci_star = self.subset(row_ixs=i_star).cost_matrix.ravel() ** 2 # (m,) cj_star = self.subset(col_ixs=j_star).cost_matrix.ravel() ** 2 # (n,) p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,) p_row /= jnp.sum(p_row) row_ixs = jax.random.choice(rng3, n, shape=(n_subset,), p=p_row) # (n_subset, m) s = self.subset(row_ixs=row_ixs).cost_matrix s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None]) p_col = jnp.sum(s ** 2, axis=0) # (m,) p_col /= jnp.sum(p_col) # (n_subset,) col_ixs = jax.random.choice(rng4, m, shape=(n_subset,), p=p_col) # (n_subset, n_subset) w = s[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :]) U, _, V = jsp.linalg.svd(w) U = U[:, :rank] # (n_subset, rank) U = (s.T @ U) / jnp.linalg.norm(w.T @ U, axis=0) # (m, rank) _, d, v = jnp.linalg.svd(U.T @ U) # (k,), (k, k) v = v.T / jnp.sqrt(d)[None, :] inv_scale = (1.0 / jnp.sqrt(n_subset)) col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,) # (n, n_subset) A_trans = self.subset(col_ixs=col_ixs).cost_matrix * inv_scale B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k) M = jnp.linalg.inv(B.T @ B) # (k, k) V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k) cost_1 = V cost_2 = U return low_rank.LRCGeometry( cost_1=cost_1, cost_2=cost_2, epsilon=self._epsilon_init, relative_epsilon=self._relative_epsilon, scale_cost=self._scale_cost, scale_factor=scale, )
[docs] def subset( self, row_ixs: Optional[jnp.ndarray] = None, col_ixs: Optional[jnp.ndarray] = None ) -> "Geometry": """Subset rows or columns of a geometry. Args: row_ixs: Row indices. If :obj:`None`, use all rows. col_ixs: Column indices. If :obj:`None`, use all columns. Returns: The subsetted geometry. """ (cost, kernel, *rest), aux_data = self.tree_flatten() row_ixs = row_ixs if row_ixs is None else jnp.atleast_1d(row_ixs) col_ixs = col_ixs if col_ixs is None else jnp.atleast_1d(col_ixs) if cost is not None: cost = cost if row_ixs is None else cost[row_ixs] cost = cost if col_ixs is None else cost[:, col_ixs] if kernel is not None: kernel = kernel if row_ixs is None else kernel[row_ixs] kernel = kernel if col_ixs is None else kernel[:, col_ixs] return type(self).tree_unflatten(aux_data, (cost, kernel, *rest))
@property def dtype(self) -> jnp.dtype: """The data type.""" if self._cost_matrix is not None: return self._cost_matrix.dtype return self._kernel_matrix.dtype def tree_flatten(self): # noqa: D102 return ( self._cost_matrix, self._kernel_matrix, self._epsilon_init, ), { "scale_cost": self._scale_cost, "relative_epsilon": self._relative_epsilon, } @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 cost, kernel, epsilon = children return cls(cost, kernel_matrix=kernel, epsilon=epsilon, **aux_data)