# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union
import jax
import jax.numpy as jnp
import jax.scipy as jsp
from ott import utils
from ott.geometry import epsilon_scheduler
from ott.math import utils as mu
if TYPE_CHECKING:
  from ott.geometry import low_rank
__all__ = ["Geometry", "is_linear", "is_affine"]
@jax.tree_util.register_pytree_node_class
class Geometry:
r"""Base class to define ground costs/kernels used in optimal transport.
Optimal transport problems are intrinsically geometric: they compute an
optimal way to transport mass from one configuration onto another. Defining
what is meant by optimality of transport requires defining a cost of moving
mass from one among several sources towards one out of multiple targets.
These sources and targets can be provided as points in vector spaces, grids,
or, more generally, be described exclusively through a (dissimilarity) cost
matrix or, almost equivalently, a (similarity) kernel matrix.
Once that cost or kernel matrix is set, the ``Geometry`` class provides the
basic operations needed to run the Sinkhorn algorithm.
Args:
cost_matrix: Cost matrix of shape ``[n, m]``.
kernel_matrix: Kernel matrix of shape ``[n, m]``.
epsilon: Regularization parameter. If ``scale_epsilon = None`` and either
``relative_epsilon = True``, or ``relative_epsilon = None`` and
``epsilon = None`` in :class:`~ott.geometry.epsilon_scheduler.Epsilon`
is used, ``scale_epsilon`` is set to the :attr:`mean_cost_matrix`. If
``epsilon = None``, use :math:`0.05`.
relative_epsilon: whether epsilon is passed relative to the scale of the
problem, here understood as the value of the :attr:`mean_cost_matrix`.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean' and 'max_cost'. Alternatively, a float factor can be
given to rescale the cost such that ``cost_matrix /= scale_cost``.
If ``True``, use 'mean'.
src_mask: Mask specifying valid rows when computing some statistics of
:attr:`cost_matrix`, see :attr:`src_mask`.
tgt_mask: Mask specifying valid columns when computing some statistics of
:attr:`cost_matrix`, see :attr:`tgt_mask`.
Note:
When defining a :class:`~ott.geometry.geometry.Geometry` through a
``cost_matrix``, it is important to select an ``epsilon`` regularization
parameter that is meaningful. That parameter can be provided by the user,
or assigned a default value through a simple rule,
using the :attr:`mean_cost_matrix`.
"""
def __init__(
self,
cost_matrix: Optional[jnp.ndarray] = None,
kernel_matrix: Optional[jnp.ndarray] = None,
epsilon: Optional[Union[float, epsilon_scheduler.Epsilon]] = None,
relative_epsilon: Optional[bool] = None,
scale_cost: Union[bool, int, float, Literal["mean", "max_cost",
"median"]] = 1.0,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
):
self._cost_matrix = cost_matrix
self._kernel_matrix = kernel_matrix
# needed for `copy_epsilon`, because of the `isinstance` check
self._epsilon_init = epsilon if isinstance(
epsilon, epsilon_scheduler.Epsilon
) else epsilon_scheduler.Epsilon(epsilon)
self._relative_epsilon = relative_epsilon
self._scale_cost = "mean" if scale_cost is True else scale_cost
self._src_mask = src_mask
self._tgt_mask = tgt_mask
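# A minimal usage sketch (illustrative only; the toy cost matrix and epsilon
# below are hypothetical, not part of this module):
#
#   import jax.numpy as jnp
#   from ott.geometry import geometry
#
#   cost = jnp.array([[0.0, 1.0], [1.0, 0.0]])
#   geom = geometry.Geometry(cost_matrix=cost, epsilon=1e-1)
#   geom.shape          # (2, 2)
#   geom.kernel_matrix  # == jnp.exp(-cost / 1e-1)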
@property
def cost_rank(self) -> Optional[int]:
"""Output rank of cost matrix, if any was provided."""
@property
def cost_matrix(self) -> jnp.ndarray:
"""Cost matrix, recomputed from kernel if only kernel was specified."""
if self._cost_matrix is None:
# if no epsilon was passed on to the geometry, assume the regularization
# is one by default; `eps` below is only a numerical guard for the log.
eps = jnp.finfo(self._kernel_matrix.dtype).tiny
cost = -jnp.log(self._kernel_matrix + eps)
cost *= self.inv_scale_cost
return cost if self._epsilon_init is None else self.epsilon * cost
return self._cost_matrix * self.inv_scale_cost
@property
def median_cost_matrix(self) -> float:
"""Median of the :attr:`cost_matrix`."""
geom = self._masked_geom(mask_value=jnp.nan)
return jnp.nanmedian(geom.cost_matrix)  # will fail for online point clouds
@property
def mean_cost_matrix(self) -> float:
"""Mean of the :attr:`cost_matrix`."""
tmp = self._masked_geom().apply_cost(self._n_normed_ones).squeeze()
return jnp.sum(tmp * self._m_normed_ones)
@property
def kernel_matrix(self) -> jnp.ndarray:
"""Kernel matrix.
Either provided by user or recomputed from :attr:`cost_matrix`.
"""
if self._kernel_matrix is None:
return jnp.exp(-(self._cost_matrix * self.inv_scale_cost / self.epsilon))
return self._kernel_matrix ** self.inv_scale_cost
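# Sketch of the cost/kernel duality (illustrative; `cost` and `eps` are
# hypothetical placeholders): a geometry built from the kernel
# K = exp(-cost / eps) recovers the original cost matrix.
#
#   k = jnp.exp(-cost / eps)
#   geom_k = geometry.Geometry(kernel_matrix=k, epsilon=eps)
#   geom_k.cost_matrix  # approximately equal to `cost`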
@property
def _epsilon(self) -> epsilon_scheduler.Epsilon:
(target, scale_eps, _, _), _ = self._epsilon_init.tree_flatten()
rel = self._relative_epsilon
use_mean_scale = rel is True or (rel is None and target is None)
if scale_eps is None and use_mean_scale:
scale_eps = jax.lax.stop_gradient(self.mean_cost_matrix)
if isinstance(self._epsilon_init, epsilon_scheduler.Epsilon):
return self._epsilon_init.set(scale_epsilon=scale_eps)
return epsilon_scheduler.Epsilon(
target=5e-2 if target is None else target, scale_epsilon=scale_eps
)
@property
def epsilon(self) -> float:
"""Epsilon regularization value."""
return self._epsilon.target
@property
def shape(self) -> Tuple[int, int]:
"""Shape of the geometry."""
mat = (
self._kernel_matrix if self._cost_matrix is None else self._cost_matrix
)
if mat is not None:
return mat.shape
return 0, 0
@property
def can_LRC(self) -> bool:
"""Check quickly if casting geometry as LRC makes sense.
This check is only carried out using basic considerations from the geometry,
not using a rigorous check involving, e.g., SVD.
"""
return False
@property
def is_squared_euclidean(self) -> bool:
"""Whether cost is computed by taking squared Euclidean distance."""
return False
@property
def is_online(self) -> bool:
"""Whether geometry cost/kernel should be recomputed on the fly."""
return False
@property
def is_symmetric(self) -> bool:
"""Whether geometry cost/kernel is a symmetric matrix."""
mat = self.kernel_matrix if self.cost_matrix is None else self.cost_matrix
return (
mat.shape[0] == mat.shape[1] and jnp.all(mat == mat.T)
) if mat is not None else False
@property
def inv_scale_cost(self) -> float:
"""Compute and return inverse of scaling factor for cost matrix."""
if isinstance(self._scale_cost,
(int, float)) or utils.is_jax_array(self._scale_cost):
return 1.0 / self._scale_cost
self = self._masked_geom(mask_value=jnp.nan)
if self._scale_cost == "max_cost":
return 1.0 / jnp.nanmax(self._cost_matrix)
if self._scale_cost == "mean":
return 1.0 / jnp.nanmean(self._cost_matrix)
if self._scale_cost == "median":
return 1.0 / jnp.nanmedian(self._cost_matrix)
raise ValueError(f"Scaling {self._scale_cost} not implemented.")
def _set_scale_cost(self, scale_cost: Union[bool, float, str]) -> "Geometry":
# case when `geom` doesn't have `scale_cost` or doesn't need to be modified
# `False` retains the original scale
if scale_cost is False or scale_cost == self._scale_cost:
return self
children, aux_data = self.tree_flatten()
aux_data["scale_cost"] = scale_cost
return type(self).tree_unflatten(aux_data, children)
def copy_epsilon(self, other: "Geometry") -> "Geometry":
"""Copy the epsilon parameters from another geometry."""
other_epsilon = other._epsilon
children, aux_data = self.tree_flatten()
new_children = []
for child in children:
if isinstance(child, epsilon_scheduler.Epsilon):
child = child.set(
target=other_epsilon._target_init,
scale_epsilon=other_epsilon._scale_epsilon
)
new_children.append(child)
aux_data["relative_epsilon"] = False
return type(self).tree_unflatten(aux_data, new_children)
# The functions below are at the core of Sinkhorn iterations; they are
# implemented here in their default form, either in lse mode (using cost
# matrices directly, in stabilized form) or kernel mode (using kernel
# matrices).
def apply_lse_kernel(
self,
f: jnp.ndarray,
g: jnp.ndarray,
eps: float,
vec: Optional[jnp.ndarray] = None,
axis: int = 0
) -> Tuple[jnp.ndarray, jnp.ndarray]:
r"""Apply :attr:`kernel_matrix` in log domain.
This function applies the ground geometry's kernel in log domain, using
a stabilized formulation. At a high level, this iteration performs either:
- output = eps * log (K (exp(g / eps) * vec)) (1)
- output = eps * log (K'(exp(f / eps) * vec)) (2)
K is implicitly exp(-cost_matrix/eps).
To carry this out in a stabilized way, we take advantage of the fact that
the entries of the matrix ``f[:,*] + g[*,:] - C`` are all negative, and
therefore their exponential never overflows, to add (and subtract
afterwards) f and g in formulas (1) and (2) respectively.
Args:
f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix
g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix
eps: float, regularization strength
vec: jnp.ndarray [num_a or num_b,] , when not None, this has the effect of
doing log-Kernel computations with an additional elementwise
multiplication of exp(g / eps) by a vector. This is carried out by
adding weights to the log-sum-exp function, and needs to handle signs
separately.
axis: summing over axis 0 when doing (2), or over axis 1 when doing (1)
Returns:
A jnp.ndarray corresponding to output above, depending on axis.
"""
w_res, w_sgn = self._softmax(f, g, eps, vec, axis)
remove = f if axis == 1 else g
return w_res - jnp.where(jnp.isfinite(remove), remove, 0), w_sgn
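# Sketch (hypothetical potentials f, g of sizes num_a, num_b): with axis=1,
# this evaluates eps * log(K exp(g / eps)) row-wise, without forming K.
#
#   out, sgn = geom.apply_lse_kernel(f, g, eps=geom.epsilon, axis=1)
#   # out[i] == eps * logsumexp((g - cost[i, :]) / eps); `sgn` carries
#   # signs only when a weight vector `vec` is passed.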
def apply_kernel(
self,
scaling: jnp.ndarray,
eps: Optional[float] = None,
axis: int = 0,
) -> jnp.ndarray:
"""Apply :attr:`kernel_matrix` on positive scaling vector.
Args:
scaling: jnp.ndarray [num_a or num_b] , scaling of size num_rows or
num_cols of kernel_matrix
eps: alternative regularization strength; when passed, the kernel is
raised to the power ``epsilon / eps`` before being applied.
axis: standard kernel product if axis is 1, transpose if 0.
Returns:
a jnp.ndarray corresponding to output above, depending on axis.
"""
if eps is None:
kernel = self.kernel_matrix
else:
kernel = self.kernel_matrix ** (self.epsilon / eps)
kernel = kernel if axis == 1 else kernel.T
return jnp.dot(kernel, scaling)
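# Sketch (hypothetical scaling vectors `u` of size num_a, `v` of size
# num_b): with the default eps, this is a plain kernel-vector product.
#
#   geom.apply_kernel(v, axis=1)  # == geom.kernel_matrix @ v
#   geom.apply_kernel(u, axis=0)  # == geom.kernel_matrix.T @ u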
def marginal_from_potentials(
self,
f: jnp.ndarray,
g: jnp.ndarray,
axis: int = 0,
) -> jnp.ndarray:
"""Output marginal of transportation matrix from potentials.
This first applies the lse kernel in the standard way, removes the
correction used to stabilize computations, and lifts this with an exp to
recover either of the marginals corresponding to the transport map induced
by the potentials.
Args:
f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix
g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix
axis: axis along which to integrate, returns marginal on other axis.
Returns:
a vector of marginals of the transport matrix.
"""
h = (f if axis == 1 else g)
z = self.apply_lse_kernel(f, g, self.epsilon, axis=axis)[0]
return jnp.exp((z + h) / self.epsilon)
def marginal_from_scalings(
self,
u: jnp.ndarray,
v: jnp.ndarray,
axis: int = 0,
) -> jnp.ndarray:
"""Output marginal of transportation matrix from scalings."""
u, v = (v, u) if axis == 0 else (u, v)
return u * self.apply_kernel(v, eps=self.epsilon, axis=axis)
def transport_from_potentials(
self, f: jnp.ndarray, g: jnp.ndarray
) -> jnp.ndarray:
"""Output transport matrix from potentials."""
return jnp.exp(self._center(f, g) / self.epsilon)
def transport_from_scalings(
self, u: jnp.ndarray, v: jnp.ndarray
) -> jnp.ndarray:
"""Output transport matrix from pair of scalings."""
return self.kernel_matrix * u[:, jnp.newaxis] * v[jnp.newaxis, :]
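# Sketch of the equivalence between the two parametrizations (u, v are
# hypothetical positive scalings):
#
#   P = geom.transport_from_scalings(u, v)
#   # == jnp.diag(u) @ geom.kernel_matrix @ jnp.diag(v)
#   # == geom.transport_from_potentials(geom.potential_from_scaling(u),
#   #                                   geom.potential_from_scaling(v))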
# Functions that are not supposed to be changed by inherited classes.
# These are the point of entry for Sinkhorn's algorithm to use a geometry.
def update_potential(
self,
f: jnp.ndarray,
g: jnp.ndarray,
log_marginal: jnp.ndarray,
iteration: Optional[int] = None,
axis: int = 0,
) -> jnp.ndarray:
"""Carry out one Sinkhorn update for potentials, i.e. in log space.
Args:
f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix
g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix
log_marginal: targeted marginal
iteration: used to compute epsilon from schedule, if provided.
axis: axis along which the update should be carried out.
Returns:
new potential value, g if axis=0, f if axis is 1.
"""
eps = self._epsilon.at(iteration)
app_lse = self.apply_lse_kernel(f, g, eps, axis=axis)[0]
return eps * log_marginal - jnp.where(jnp.isfinite(app_lse), app_lse, 0)
def update_scaling(
self,
scaling: jnp.ndarray,
marginal: jnp.ndarray,
iteration: Optional[int] = None,
axis: int = 0,
) -> jnp.ndarray:
"""Carry out one Sinkhorn update for scalings, using kernel directly.
Args:
scaling: jnp.ndarray of num_a or num_b positive values.
marginal: targeted marginal
iteration: used to compute epsilon from schedule, if provided.
axis: axis along which the update should be carried out.
Returns:
new scaling vector, of size num_b if axis=0, num_a if axis is 1.
"""
eps = self._epsilon.at(iteration)
app_kernel = self.apply_kernel(scaling, eps, axis=axis)
return marginal / jnp.where(app_kernel > 0, app_kernel, 1.0)
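# Sketch of one Sinkhorn iteration written with these updates (a, b are
# hypothetical marginal weights; f, g current potentials; u, v scalings):
#
#   g = geom.update_potential(f, g, jnp.log(b), iteration=0, axis=0)
#   f = geom.update_potential(f, g, jnp.log(a), iteration=0, axis=1)
#   # or, equivalently in kernel mode:
#   v = geom.update_scaling(u, b, iteration=0, axis=0)
#   u = geom.update_scaling(v, a, iteration=0, axis=1)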
# Helper functions
def _center(self, f: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray:
return f[:, jnp.newaxis] + g[jnp.newaxis, :] - self.cost_matrix
def _softmax(
self, f: jnp.ndarray, g: jnp.ndarray, eps: float,
vec: Optional[jnp.ndarray], axis: int
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Apply softmax row or column wise, weighted by vec."""
if vec is not None:
if axis == 0:
vec = vec.reshape((-1, 1))
lse_output = mu.logsumexp(
self._center(f, g) / eps, b=vec, axis=axis, return_sign=True
)
return eps * lse_output[0], lse_output[1]
lse_output = mu.logsumexp(
self._center(f, g) / eps, axis=axis, return_sign=False
)
return eps * lse_output, jnp.array([1.0])
@functools.partial(jax.vmap, in_axes=[None, None, None, 0, None])
def _apply_transport_from_potentials(
self, f: jnp.ndarray, g: jnp.ndarray, vec: jnp.ndarray, axis: int
) -> jnp.ndarray:
"""Apply lse_kernel to arbitrary vector while keeping track of signs."""
lse_res, lse_sgn = self.apply_lse_kernel(
f, g, self.epsilon, vec=vec, axis=axis
)
lse_res += f if axis == 1 else g
return lse_sgn * jnp.exp(lse_res / self.epsilon)
# wrapper to allow default option for axis.
def apply_transport_from_potentials(
self,
f: jnp.ndarray,
g: jnp.ndarray,
vec: jnp.ndarray,
axis: int = 0
) -> jnp.ndarray:
"""Apply transport matrix computed from potentials to a (batched) vec.
This approach does not instantiate the transport matrix itself, but instead
uses the potentials to apply the transport via apply_lse_kernel, therefore
guaranteeing stability and a lower memory footprint.
Computations are done in log space, and take advantage of the
(b=..., return_sign=True) optional parameters of logsumexp.
Args:
f: jnp.ndarray [num_a,] , potential of size num_rows of cost_matrix
g: jnp.ndarray [num_b,] , potential of size num_cols of cost_matrix
vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied
by transport matrix corresponding to potentials f, g, and geom.
axis: axis to differentiate left (0) or right (1) multiply.
Returns:
ndarray of the size of vec.
"""
if vec.ndim == 1:
return self._apply_transport_from_potentials(
f, g, vec[jnp.newaxis, :], axis
)[0, :]
return self._apply_transport_from_potentials(f, g, vec, axis)
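# Sketch (hypothetical vector `vec` of size num_b): right-multiplication by
# the implicit transport matrix, without ever materializing it.
#
#   out = geom.apply_transport_from_potentials(f, g, vec, axis=1)
#   # == geom.transport_from_potentials(f, g) @ vec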
@functools.partial(jax.vmap, in_axes=[None, None, None, 0, None])
def _apply_transport_from_scalings(
self, u: jnp.ndarray, v: jnp.ndarray, vec: jnp.ndarray, axis: int
):
u, v = (u, v * vec) if axis == 1 else (v, u * vec)
return u * self.apply_kernel(v, eps=self.epsilon, axis=axis)
# wrapper to allow default option for axis
def apply_transport_from_scalings(
self,
u: jnp.ndarray,
v: jnp.ndarray,
vec: jnp.ndarray,
axis: int = 0
) -> jnp.ndarray:
"""Apply transport matrix computed from scalings to a (batched) vec.
This approach does not instantiate the transport matrix itself, but
relies instead on the apply_kernel function.
Args:
u: jnp.ndarray [num_a,] , scaling of size num_rows of cost_matrix
v: jnp.ndarray [num_b,] , scaling of size num_cols of cost_matrix
vec: jnp.ndarray [batch, num_a or num_b], vector that will be multiplied
by transport matrix corresponding to scalings u, v, and geom.
axis: axis to differentiate left (0) or right (1) multiply.
Returns:
ndarray of the size of vec.
"""
if vec.ndim == 1:
return self._apply_transport_from_scalings(
u, v, vec[jnp.newaxis, :], axis
)[0, :]
return self._apply_transport_from_scalings(u, v, vec, axis)
def potential_from_scaling(self, scaling: jnp.ndarray) -> jnp.ndarray:
"""Compute dual potential vector from scaling vector.
Args:
scaling: vector.
Returns:
a vector of the same size.
"""
return self.epsilon * jnp.log(scaling)
def scaling_from_potential(self, potential: jnp.ndarray) -> jnp.ndarray:
"""Compute scaling vector from dual potential.
Args:
potential: vector.
Returns:
a vector of the same size.
"""
finite = jnp.isfinite(potential)
return jnp.where(
finite, jnp.exp(jnp.where(finite, potential / self.epsilon, 0.0)), 0.0
)
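# Sketch of the (potential, scaling) correspondence (f is a hypothetical
# finite potential vector):
#
#   u = geom.scaling_from_potential(f)   # u = exp(f / epsilon)
#   f2 = geom.potential_from_scaling(u)  # recovers f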
def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
"""Apply elementwise-square of cost matrix to array (vector or matrix).
This function applies the elementwise square of the ground geometry's cost
matrix, to perform either
output = C2 arr (if axis=1)
output = C2' arr (if axis=0)
where C2 := C * C (elementwise) is [num_a, num_b]. This is useful when the
cost matrix itself is computed as a squared-Euclidean distance between
vectors, and therefore admits an explicit low-rank factorization.
Args:
arr: array.
axis: axis of the array on which the cost matrix should be applied.
Returns:
An array, [num_b, p] if axis=0 or [num_a, p] if axis=1.
"""
return self.apply_cost(arr, axis=axis, fn=lambda x: x ** 2)
def apply_cost(
self,
arr: jnp.ndarray,
axis: int = 0,
fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
**kwargs: Any
) -> jnp.ndarray:
"""Apply :attr:`cost_matrix` to array (vector or matrix).
This function applies the ground geometry's cost matrix, to perform either
output = C arr (if axis=1)
output = C' arr (if axis=0)
where C is [num_a, num_b]
Args:
arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied by
the cost matrix.
axis: standard cost matrix if axis=1, transpose if 0
fn: function to apply to cost matrix element-wise before the dot product
kwargs: Keyword arguments for :meth:`_apply_cost_to_vec`.
Returns:
An array, [num_b, p] if axis=0 or [num_a, p] if axis=1
"""
if arr.ndim == 1:
arr = arr.reshape(-1, 1)
app = functools.partial(self._apply_cost_to_vec, axis=axis, fn=fn, **kwargs)
return jax.vmap(app, in_axes=1, out_axes=1)(arr)
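# Sketch (hypothetical vector `x` of size num_a): applying the transposed
# cost, optionally through an elementwise `fn`.
#
#   geom.apply_cost(x, axis=0)   # == cost_matrix.T @ x, shape [num_b, 1]
#   geom.apply_cost(x, axis=0, fn=lambda c: c ** 2)   # squared-cost apply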
def _apply_cost_to_vec(
self,
vec: jnp.ndarray,
axis: int = 0,
fn=None,
**_: Any,
) -> jnp.ndarray:
"""Apply ``[num_a, num_b]`` fn(cost) (or transpose) to vector.
Args:
vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector
axis: axis on which the reduction is done.
fn: function optionally applied to cost matrix element-wise, before the
dot product.
Returns:
A jnp.ndarray corresponding to cost x vector
"""
matrix = self.cost_matrix.T if axis == 0 else self.cost_matrix
matrix = fn(matrix) if fn is not None else matrix
return jnp.dot(matrix, vec)
@classmethod
def prepare_divergences(
cls,
*args: Any,
static_b: bool = False,
**kwargs: Any
) -> Tuple["Geometry", ...]:
"""Instantiate 2 (or 3) geometries to compute a Sinkhorn divergence."""
size = 2 if static_b else 3
nones = [None, None, None]
cost_matrices = kwargs.pop("cost_matrix", args)
kernel_matrices = kwargs.pop("kernel_matrix", nones)
cost_matrices = cost_matrices if cost_matrices is not None else nones
return tuple(
cls(cost_matrix=arg1, kernel_matrix=arg2, **kwargs)
for arg1, arg2, _ in zip(cost_matrices, kernel_matrices, range(size))
)
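# Sketch (cost_xy, cost_xx, cost_yy are hypothetical cost matrices): build
# the three geometries needed for a Sinkhorn divergence in one call.
#
#   g_xy, g_xx, g_yy = geometry.Geometry.prepare_divergences(
#       cost_xy, cost_xx, cost_yy, epsilon=1e-1
#   )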
def to_LRCGeometry(
self,
rank: int = 0,
tol: float = 1e-2,
rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
scale: float = 1.
) -> "low_rank.LRCGeometry":
r"""Factorize the cost matrix using either SVD (full) or :cite:`indyk:19`.
When ``rank = 0`` (the default) or ``rank >= min(n, m)``, use
:func:`jax.numpy.linalg.svd` (full factorization). For other values, use
the sublinear-time routine of :cite:`indyk:19`.
Uses the implementation of :cite:`scetbon:21`, algorithm 4.
It holds that with probability *0.99*,
:math:`||A - UV||_F^2 \leq || A - A_k ||_F^2 + tol \cdot ||A||_F^2`,
where :math:`A` is the ``n x m`` cost matrix, :math:`UV` the factorization
computed in sublinear time, and :math:`A_k` the best rank-:math:`k`
approximation.
Args:
rank: Target rank of the :attr:`cost_matrix`.
tol: Tolerance of the error. The total number of sampled points is
:math:`\min(n, m, \frac{rank}{tol})`.
rng: The PRNG key to use for initializing the model.
scale: Value used to rescale the factors of the low-rank geometry.
Useful when this geometry is used in the linear term of fused GW.
Returns:
Low-rank geometry.
"""
from ott.geometry import low_rank
assert rank >= 0, f"Rank must be non-negative, got {rank}."
n, m = self.shape
if rank == 0 or rank >= min(n, m):
# TODO(marcocuturi): add hermitian=self.is_symmetric, currently bugging.
u, s, vh = jnp.linalg.svd(
self.cost_matrix,
full_matrices=False,
compute_uv=True,
)
cost_1 = u
cost_2 = (s[:, None] * vh).T
else:
rng1, rng2, rng3, rng4, rng5 = jax.random.split(rng, 5)
n_subset = min(int(rank / tol), n, m)
i_star = jax.random.randint(rng1, shape=(), minval=0, maxval=n)
j_star = jax.random.randint(rng2, shape=(), minval=0, maxval=m)
ci_star = self.subset(i_star, None).cost_matrix.ravel() ** 2 # (m,)
cj_star = self.subset(None, j_star).cost_matrix.ravel() ** 2 # (n,)
p_row = cj_star + ci_star[j_star] + jnp.mean(ci_star) # (n,)
p_row /= jnp.sum(p_row)
row_ixs = jax.random.choice(rng3, n, shape=(n_subset,), p=p_row)
# (n_subset, m)
s = self.subset(row_ixs, None).cost_matrix
s /= jnp.sqrt(n_subset * p_row[row_ixs][:, None])
p_col = jnp.sum(s ** 2, axis=0) # (m,)
p_col /= jnp.sum(p_col)
# (n_subset,)
col_ixs = jax.random.choice(rng4, m, shape=(n_subset,), p=p_col)
# (n_subset, n_subset)
w = s[:, col_ixs] / jnp.sqrt(n_subset * p_col[col_ixs][None, :])
U, _, V = jsp.linalg.svd(w)
U = U[:, :rank] # (n_subset, rank)
U = (s.T @ U) / jnp.linalg.norm(w.T @ U, axis=0) # (m, rank)
_, d, v = jnp.linalg.svd(U.T @ U) # (k,), (k, k)
v = v.T / jnp.sqrt(d)[None, :]
inv_scale = (1. / jnp.sqrt(n_subset))
col_ixs = jax.random.choice(rng5, m, shape=(n_subset,)) # (n_subset,)
# (n, n_subset)
A_trans = self.subset(None, col_ixs).cost_matrix * inv_scale
B = (U[col_ixs, :] @ v * inv_scale) # (n_subset, k)
M = jnp.linalg.inv(B.T @ B) # (k, k)
V = jnp.linalg.multi_dot([A_trans, B, M.T, v.T]) # (n, k)
cost_1 = V
cost_2 = U
return low_rank.LRCGeometry(
cost_1=cost_1,
cost_2=cost_2,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale_cost=self._scale_cost,
scale_factor=scale,
)
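# Sketch of the factorization contract in the full-rank case:
#
#   lrc = geom.to_LRCGeometry(rank=0)
#   # lrc.cost_1 @ lrc.cost_2.T reconstructs geom.cost_matrix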
def subset(
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any
) -> "Geometry":
"""Subset rows or columns of a geometry.
Args:
src_ixs: Row indices. If ``None``, use all rows.
tgt_ixs: Column indices. If ``None``, use all columns.
kwargs: Keyword arguments to override the initialization.
Returns:
The modified geometry.
"""
def subset_fn(
arr: Optional[jnp.ndarray],
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
) -> Optional[jnp.ndarray]:
if arr is None:
return None
if src_ixs is not None:
arr = arr[jnp.atleast_1d(src_ixs)]
if tgt_ixs is not None:
arr = arr[:, jnp.atleast_1d(tgt_ixs)]
return arr # noqa: RET504
return self._mask_subset_helper(
src_ixs,
tgt_ixs,
fn=subset_fn,
propagate_mask=True,
**kwargs,
)
def mask(
self,
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
mask_value: float = 0.,
) -> "Geometry":
"""Mask rows or columns of a geometry.
The mask is used only when computing some statistics of the
:attr:`cost_matrix`.
- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
Args:
src_mask: Row mask. Can be specified either as a boolean array of shape
``[num_a,]`` or as an array of indices. If ``None``, no mask is applied.
tgt_mask: Column mask. Can be specified either as a boolean array of shape
``[num_b,]`` or as an array of indices. If ``None``, no mask is applied.
mask_value: Value to use for masking.
Returns:
The masked geometry.
"""
def mask_fn(
arr: Optional[jnp.ndarray],
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
) -> Optional[jnp.ndarray]:
if arr is None:
return arr
assert arr.ndim == 2, arr.ndim
if src_mask is not None:
arr = jnp.where(src_mask[:, None], arr, mask_value)
if tgt_mask is not None:
arr = jnp.where(tgt_mask[None, :], arr, mask_value)
return arr # noqa: RET504
src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
return self._mask_subset_helper(
src_mask, tgt_mask, fn=mask_fn, propagate_mask=False
)
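# Sketch of the difference between the two (ixs is a hypothetical index
# array): `subset` changes the shape, `mask` only affects statistics.
#
#   geom.subset(ixs, None).shape  # (len(ixs), num_b)
#   geom.mask(ixs, None).shape    # unchanged; mean/median/scale stats
#                                 # are computed on rows in `ixs` only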
def _mask_subset_helper(
self,
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
*,
fn: Callable[
[Optional[jnp.ndarray], Optional[jnp.ndarray], Optional[jnp.ndarray]],
Optional[jnp.ndarray]],
propagate_mask: bool,
**kwargs: Any,
) -> "Geometry":
(cost, kernel, eps, src_mask, tgt_mask), aux_data = self.tree_flatten()
cost = fn(cost, src_ixs, tgt_ixs)
kernel = fn(kernel, src_ixs, tgt_ixs)
if propagate_mask:
src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
src_mask = fn(src_mask, src_ixs, None)
tgt_mask = fn(tgt_mask, tgt_ixs, None)
aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(
aux_data, [cost, kernel, eps, src_mask, tgt_mask]
)
@property
def src_mask(self) -> Optional[jnp.ndarray]:
"""Mask of shape ``[num_a,]`` to compute :attr:`cost_matrix` statistics.
Specifically, it is used when computing:
- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""
return self._normalize_mask(self._src_mask, self.shape[0])
@property
def tgt_mask(self) -> Optional[jnp.ndarray]:
"""Mask of shape ``[num_b,]`` to compute :attr:`cost_matrix` statistics.
Specifically, it is used when computing:
- :attr:`mean_cost_matrix`
- :attr:`median_cost_matrix`
- :attr:`inv_scale_cost`
"""
return self._normalize_mask(self._tgt_mask, self.shape[1])
@property
def dtype(self) -> jnp.dtype:
"""The data type."""
return (
self._kernel_matrix if self._cost_matrix is None else self._cost_matrix
).dtype
def _masked_geom(self, mask_value: float = 0.) -> "Geometry":
"""Mask geometry based on :attr:`src_mask` and :attr:`tgt_mask`."""
src_mask, tgt_mask = self.src_mask, self.tgt_mask
if src_mask is None and tgt_mask is None:
return self
return self.mask(src_mask, tgt_mask, mask_value=mask_value)
@property
def _n_normed_ones(self) -> jnp.ndarray:
"""Normalized array of shape ``[num_a,]``."""
mask = self.src_mask
arr = jnp.ones(self.shape[0]) if mask is None else mask
return arr / jnp.sum(arr)
@property
def _m_normed_ones(self) -> jnp.ndarray:
"""Normalized array of shape ``[num_b,]``."""
mask = self.tgt_mask
arr = jnp.ones(self.shape[1]) if mask is None else mask
return arr / jnp.sum(arr)
@staticmethod
def _normalize_mask(mask: Optional[Union[int, jnp.ndarray]],
size: int) -> Optional[jnp.ndarray]:
"""Convert array of indices to a boolean mask."""
if mask is None:
return None
mask = jnp.atleast_1d(mask)
if not jnp.issubdtype(mask, (bool, jnp.bool_)):
mask = jnp.isin(jnp.arange(size), mask)
assert mask.shape == (size,)
return mask
def tree_flatten(self): # noqa: D102
return (
self._cost_matrix, self._kernel_matrix, self._epsilon_init,
self._src_mask, self._tgt_mask
), {
"scale_cost": self._scale_cost,
"relative_epsilon": self._relative_epsilon
}
@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
cost, kernel, eps, src_mask, tgt_mask = children
return cls(
cost, kernel, eps, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data
)
def is_affine(fn) -> bool:
"""Test heuristically if a function is affine."""
x = jnp.arange(10.0)
out = jax.vmap(jax.grad(fn))(x)
return jnp.sum(jnp.diff(jnp.abs(out))) == 0.0
def is_linear(fn) -> bool:
"""Test heuristically if a function is linear."""
return jnp.logical_and(fn(0.0) == 0.0, is_affine(fn))
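# Sketch:
#   is_affine(lambda x: 2.0 * x + 1.0)  # True
#   is_linear(lambda x: 2.0 * x + 1.0)  # False (non-zero value at 0)
#   is_linear(lambda x: 2.0 * x)        # True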