# Source code for ott.geometry.low_rank

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import costs, geometry
from ott.math import utils as mu

__all__ = ["LRCGeometry", "LRKGeometry"]


@jax.tree_util.register_pytree_node_class
class LRCGeometry(geometry.Geometry):
  """Geometry whose cost is defined by a product of two low-rank matrices.

  Implements geometries that are defined as low-rank products, i.e. for which
  there exist two matrices :math:`A` and :math:`B` of :math:`r` columns such
  that the cost of the geometry equals :math:`AB^T`. Apart from being faster
  to apply to a vector, these geometries are characterized by the fact that
  adding two such geometries should be carried out by concatenating factors,
  i.e. if :math:`C = AB^T` and :math:`D = EF^T`, then
  :math:`C + D = [A, E][B, F]^T`.

  Args:
    cost_1: Array of shape ``[num_a, r]``.
    cost_2: Array of shape ``[num_b, r]``.
    bias: Constant added to the entire cost matrix.
    scale_factor: Value used to rescale the factors of the low-rank geometry.
    scale_cost: Option to rescale the cost matrix. Implemented scalings are
      'max_bound', 'mean' and 'max_cost'. Alternatively, a float factor can be
      given to rescale the cost such that ``cost_matrix /= scale_cost``.
    batch_size: Optional size of the batch used to compute online (without
      instantiating the matrix) the scale factor ``scale_cost`` of the
      :attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If :obj:`None`,
      the batch size is set to `1024`, or to the larger of the number of
      samples in :attr:`cost_1` and :attr:`cost_2` if that number is smaller
      than `1024`.
    kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
  """

  def __init__(
      self,
      cost_1: jnp.ndarray,
      cost_2: jnp.ndarray,
      bias: float = 0.0,
      scale_factor: float = 1.0,
      scale_cost: Union[int, float, Literal["mean", "max_bound",
                                            "max_cost"]] = 1.0,
      batch_size: Optional[int] = None,
      **kwargs: Any,
  ):
    super().__init__(**kwargs)
    self._cost_1 = cost_1
    self._cost_2 = cost_2
    self._bias = bias
    self._scale_factor = scale_factor
    self._scale_cost = scale_cost
    self.batch_size = batch_size

  @property
  def cost_1(self) -> jnp.ndarray:
    """First factor of the :attr:`cost_matrix`."""
    scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost)
    return scale_factor * self._cost_1

  @property
  def cost_2(self) -> jnp.ndarray:
    """Second factor of the :attr:`cost_matrix`."""
    scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost)
    return scale_factor * self._cost_2

  @property
  def bias(self) -> float:
    """Constant offset added to the entire :attr:`cost_matrix`."""
    return self._bias * self.inv_scale_cost

  @property
  def cost_rank(self) -> int:  # noqa: D102
    return self._cost_1.shape[1]

  @property
  def cost_matrix(self) -> jnp.ndarray:
    """Materialize the cost matrix."""
    return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias

  @property
  def shape(self) -> Tuple[int, int]:  # noqa: D102
    return self._cost_1.shape[0], self._cost_2.shape[0]

  @property
  def is_symmetric(self) -> bool:  # noqa: D102
    return (
        self._cost_1.shape[0] == self._cost_2.shape[0] and
        jnp.all(self._cost_1 == self._cost_2)
    )

  @property
  def inv_scale_cost(self) -> float:  # noqa: D102
    if isinstance(self._scale_cost, (int, float, jax.Array)):
      return 1.0 / self._scale_cost
    self = self._masked_geom()
    if self._scale_cost == "max_bound":
      x_norm = self._cost_1[:, 0].max()
      y_norm = self._cost_2[:, 1].max()
      max_bound = x_norm + y_norm + 2 * jnp.sqrt(x_norm * y_norm)
      return 1.0 / (max_bound + self._bias)
    if self._scale_cost == "mean":
      factor1 = jnp.dot(self._n_normed_ones, self._cost_1)
      factor2 = jnp.dot(self._cost_2.T, self._m_normed_ones)
      mean = jnp.dot(factor1, factor2) + self._bias
      return 1.0 / mean
    if self._scale_cost == "max_cost":
      return 1.0 / self.compute_max_cost()
    raise ValueError(f"Scaling {self._scale_cost} not implemented.")

  def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
    """Apply the element-wise square of the cost matrix to an array."""
    (n, m), r = self.shape, self.cost_rank
    # When applying the square of a LRCGeometry, one can either square the
    # cost matrix element-wise, or instantiate an augmented (rank r ** 2)
    # LRCGeometry and apply it. The former is O(nm), the latter O((n + m)r^2).
    # A hedged numerical check of this equivalence appears after this class.
    if n * m < (n + m) * r ** 2:  # better to use the regular apply
      return super().apply_square_cost(arr, axis)

    new_cost_1 = self.cost_1[:, :, None] * self.cost_1[:, None, :]
    new_cost_2 = self.cost_2[:, :, None] * self.cost_2[:, None, :]
    return LRCGeometry(
        cost_1=new_cost_1.reshape((n, r ** 2)),
        cost_2=new_cost_2.reshape((m, r ** 2))
    ).apply_cost(arr, axis)

  def _apply_cost_to_vec(
      self,
      vec: jnp.ndarray,
      axis: int = 0,
      fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
      is_linear: bool = False,
  ) -> jnp.ndarray:
    """Apply ``[num_a, num_b]`` fn(cost) (or its transpose) to a vector.

    Args:
      vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector.
      axis: axis on which the reduction is done.
      fn: function optionally applied to the cost matrix element-wise,
        before the dot product.
      is_linear: Whether ``fn`` is a linear function, to enable an efficient
        implementation. See :func:`ott.geometry.geometry.is_linear` for a
        heuristic that helps determine whether a function is linear.

    Returns:
      A jnp.ndarray corresponding to cost x vector.
    """

    def linear_apply(
        vec: jnp.ndarray, axis: int, fn: Callable[[jnp.ndarray], jnp.ndarray]
    ) -> jnp.ndarray:
      c1 = self.cost_1 if axis == 1 else self.cost_2
      c2 = self.cost_2 if axis == 1 else self.cost_1
      c2 = fn(c2) if fn is not None else c2
      bias = fn(self.bias) if fn is not None else self.bias
      out = jnp.dot(c1, jnp.dot(c2.T, vec))
      return out + bias * jnp.sum(vec) * jnp.ones_like(out)

    if fn is None or is_linear:
      return linear_apply(vec, axis, fn=fn)
    return super()._apply_cost_to_vec(vec, axis, fn=fn)

  def compute_max_cost(self) -> float:
    """Compute the maximum of the :attr:`cost_matrix`.

    Three cases are taken into account:

    - If the numbers of samples of ``cost_1`` and ``cost_2`` are both smaller
      than 1024 and ``batch_size`` is :obj:`None`, the ``cost_matrix`` is
      computed to obtain its maximum entry.
    - If one of the numbers of samples of ``cost_1`` or ``cost_2`` is larger
      than 1024 and ``batch_size`` is :obj:`None`, then the maximum of the
      cost matrix is calculated by batch. The batches are created along the
      longest axis of the cost matrix and their size is fixed at 1024.
    - If ``batch_size`` is provided as an int, then the maximum of the cost
      matrix is calculated by batch. The batches are created along the longest
      axis of the cost matrix and their size is fixed by ``batch_size``.

    Returns:
      Maximum of the cost matrix.
    """
    # A hedged check against the dense maximum appears after this class.
    batch_for_y = self.shape[1] > self.shape[0]

    n = self.shape[1] if batch_for_y else self.shape[0]
    p = self._cost_2.shape[1] if batch_for_y else self._cost_1.shape[1]
    carry = ((self._cost_1, self._cost_2) if batch_for_y else
             (self._cost_2, self._cost_1))

    if self.batch_size:
      batch_size = min(self.batch_size, n)
    else:
      batch_size = min(1024, max(self.shape[0], self.shape[1]))
    n_batch = n // batch_size

    def body(carry, slice_idx):
      cost1, cost2 = carry
      cost2_slice = jax.lax.dynamic_slice(
          cost2, (slice_idx * batch_size, 0), (batch_size, p)
      )
      out_slice = jnp.max(jnp.dot(cost2_slice, cost1.T))
      return carry, out_slice

    def finalize(carry):
      cost1, cost2 = carry
      return jnp.dot(cost2[n_batch * batch_size:], cost1.T)

    _, out = jax.lax.scan(body, carry, jnp.arange(n_batch))
    last_slice = finalize(carry)
    max_value = jnp.max(jnp.concatenate((out, last_slice.reshape(-1))))
    return max_value + self._bias

  def to_LRCGeometry(
      self,
      rank: int = 0,
      tol: float = 1e-2,
      rng: Optional[jax.Array] = None,
      scale: float = 1.0,
  ) -> "LRCGeometry":
    """Return self."""
    del rank, tol, rng, scale
    return self

  @property
  def can_LRC(self):  # noqa: D102
    return True

  def subset(  # noqa: D102
      self,
      src_ixs: Optional[jnp.ndarray],
      tgt_ixs: Optional[jnp.ndarray],
      **kwargs: Any,
  ) -> "LRCGeometry":

    def subset_fn(
        arr: Optional[jnp.ndarray],
        ixs: Optional[jnp.ndarray],
    ) -> jnp.ndarray:
      return arr if arr is None or ixs is None else arr[ixs, ...]

    return self._mask_subset_helper(
        src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs
    )

  def mask(  # noqa: D102
      self,
      src_mask: Optional[jnp.ndarray],
      tgt_mask: Optional[jnp.ndarray],
      mask_value: float = 0.0,
  ) -> "LRCGeometry":

    def mask_fn(
        arr: Optional[jnp.ndarray],
        mask: Optional[jnp.ndarray],
    ) -> Optional[jnp.ndarray]:
      if arr is None or mask is None:
        return arr
      return jnp.where(mask[:, None], arr, mask_value)

    src_mask = self._normalize_mask(src_mask, self.shape[0])
    tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
    return self._mask_subset_helper(
        src_mask, tgt_mask, fn=mask_fn, propagate_mask=False
    )

  def _mask_subset_helper(
      self,
      src_ixs: Optional[jnp.ndarray],
      tgt_ixs: Optional[jnp.ndarray],
      *,
      fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]],
                   Optional[jnp.ndarray]],
      propagate_mask: bool,
      **kwargs: Any,
  ) -> "LRCGeometry":
    (c1, c2, src_mask, tgt_mask, *children), aux_data = self.tree_flatten()
    c1 = fn(c1, src_ixs)
    c2 = fn(c2, tgt_ixs)
    if propagate_mask:
      src_mask = self._normalize_mask(src_mask, self.shape[0])
      tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
      src_mask = fn(src_mask, src_ixs)
      tgt_mask = fn(tgt_mask, tgt_ixs)
    aux_data = {**aux_data, **kwargs}
    return type(self).tree_unflatten(
        aux_data, [c1, c2, src_mask, tgt_mask] + children
    )

  def __add__(self, other: "LRCGeometry") -> "LRCGeometry":
    if not isinstance(other, LRCGeometry):
      return NotImplemented
    return LRCGeometry(
        cost_1=jnp.concatenate((self.cost_1, other.cost_1), axis=1),
        cost_2=jnp.concatenate((self.cost_2, other.cost_2), axis=1),
        bias=self._bias + other._bias,
        # already included in `cost_{1,2}`
        scale_factor=1.0,
        scale_cost=1.0,
    )

  @property
  def dtype(self) -> jnp.dtype:  # noqa: D102
    return self._cost_1.dtype

  def tree_flatten(self):  # noqa: D102
    return (
        self._cost_1,
        self._cost_2,
        self._src_mask,
        self._tgt_mask,
        self._epsilon_init,
        self._bias,
        self._scale_factor,
    ), {
        "scale_cost": self._scale_cost,
        "batch_size": self.batch_size,
    }

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    c1, c2, src_mask, tgt_mask, epsilon, bias, scale_factor = children
    return cls(
        c1,
        c2,
        bias=bias,
        scale_factor=scale_factor,
        epsilon=epsilon,
        src_mask=src_mask,
        tgt_mask=tgt_mask,
        **aux_data
    )

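
# Illustrative sketch (not part of the library; the helper name is
# hypothetical): a hedged numerical check of the factored-cost identities in
# the class docstring. Materializing ``cost_matrix`` recovers
# ``A @ B.T + bias``, and ``+`` concatenates factors, i.e. if C = A B^T and
# D = E F^T then C + D = [A, E][B, F]^T.
def _demo_lrc_addition() -> None:
  import numpy as np

  rng = np.random.default_rng(0)
  a, b = rng.normal(size=(5, 2)), rng.normal(size=(7, 2))
  e, f = rng.normal(size=(5, 3)), rng.normal(size=(7, 3))

  geom_c = LRCGeometry(jnp.asarray(a), jnp.asarray(b), bias=0.5)
  # Materializing the cost recovers the product of factors plus the bias.
  np.testing.assert_allclose(
      geom_c.cost_matrix, a @ b.T + 0.5, rtol=1e-5, atol=1e-5
  )

  geom_d = LRCGeometry(jnp.asarray(e), jnp.asarray(f))
  geom_sum = geom_c + geom_d
  # Adding two low-rank geometries concatenates factors, so ranks add up...
  assert geom_sum.cost_rank == geom_c.cost_rank + geom_d.cost_rank
  # ...while the materialized costs simply sum.
  np.testing.assert_allclose(
      geom_sum.cost_matrix,
      geom_c.cost_matrix + geom_d.cost_matrix,
      rtol=1e-5,
      atol=1e-5,
  )
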
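
# A second hedged sketch (hypothetical helper name): the augmented rank-r^2
# route of ``apply_square_cost`` and the batched ``compute_max_cost`` agree
# with their dense counterparts.
def _demo_lrc_square_and_max() -> None:
  import numpy as np

  rng = np.random.default_rng(1)
  # Shapes chosen so that n * m >= (n + m) * r ** 2, which triggers the
  # low-rank (augmented-factor) branch of ``apply_square_cost``.
  geom = LRCGeometry(
      jnp.asarray(rng.normal(size=(10, 2))),
      jnp.asarray(rng.normal(size=(9, 2))),
  )
  dense = np.asarray(geom.cost_matrix)

  vec = jnp.asarray(rng.normal(size=(9,)))
  out = geom.apply_square_cost(vec, axis=1)  # applies (C ** 2) @ vec
  np.testing.assert_allclose(
      np.asarray(out).squeeze(), (dense ** 2) @ np.asarray(vec),
      rtol=1e-4, atol=1e-5
  )

  # The batched maximum matches the dense maximum of the cost matrix.
  np.testing.assert_allclose(geom.compute_max_cost(), dense.max(), rtol=1e-5)
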
@jax.tree_util.register_pytree_node_class
class LRKGeometry(geometry.Geometry):
  """Low-rank kernel geometry.

  .. note::
    This constructor is not meant to be called by the user;
    please use the :meth:`from_pointcloud` method instead.

  Args:
    k1: Array of shape ``[num_a, r]`` with positive features.
    k2: Array of shape ``[num_b, r]`` with positive features.
    epsilon: Epsilon regularization.
    kwargs: Keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
  """

  def __init__(
      self,
      k1: jnp.ndarray,
      k2: jnp.ndarray,
      epsilon: Optional[float] = None,
      **kwargs: Any
  ):
    super().__init__(epsilon=epsilon, relative_epsilon=False, **kwargs)
    self.k1 = k1
    self.k2 = k2

  @classmethod
  def from_pointcloud(
      cls,
      x: jnp.ndarray,
      y: jnp.ndarray,
      *,
      kernel: Literal["gaussian", "arccos"],
      rank: int = 100,
      std: float = 1.0,
      n: int = 1,
      rng: Optional[jax.Array] = None
  ) -> "LRKGeometry":
    r"""Low-rank kernel approximation :cite:`scetbon:20`.

    Args:
      x: Array of shape ``[n, d]``.
      y: Array of shape ``[m, d]``.
      kernel: Type of the kernel to approximate.
      rank: Rank of the approximation.
      std: Depending on the ``kernel`` approximation:

        - ``'gaussian'`` - scale of the Gibbs kernel.
        - ``'arccos'`` - standard deviation of the random projections.

      n: Order of the arc-cosine kernel, see :cite:`cho:09` for reference.
      rng: Random key used for seeding.

    Returns:
      Low-rank kernel geometry.
    """
    # A hedged usage sketch of this constructor appears after this class.
    rng = utils.default_prng_key(rng)
    if kernel == "gaussian":
      r = jnp.maximum(
          jnp.linalg.norm(x, axis=-1).max(),
          jnp.linalg.norm(y, axis=-1).max()
      )
      k1 = _gaussian_kernel(rng, x, rank, eps=std, R=r)
      k2 = _gaussian_kernel(rng, y, rank, eps=std, R=r)
      eps = std
    elif kernel == "arccos":
      k1 = _arccos_kernel(rng, x, rank, n=n, std=std)
      k2 = _arccos_kernel(rng, y, rank, n=n, std=std)
      eps = 1.0
    else:
      raise NotImplementedError(kernel)
    return cls(k1, k2, epsilon=eps)

  def apply_kernel(  # noqa: D102
      self,
      scaling: jnp.ndarray,
      eps: Optional[float] = None,
      axis: int = 0,
  ) -> jnp.ndarray:
    if axis == 0:
      return self.k2 @ (self.k1.T @ scaling)
    return self.k1 @ (self.k2.T @ scaling)

  @property
  def kernel_matrix(self) -> jnp.ndarray:  # noqa: D102
    return self.k1 @ self.k2.T

  @property
  def cost_matrix(self) -> jnp.ndarray:  # noqa: D102
    eps = jnp.finfo(self.dtype).tiny
    return -self.epsilon * jnp.log(self.kernel_matrix + eps)

  @property
  def rank(self) -> int:  # noqa: D102
    return self.k1.shape[1]

  @property
  def shape(self) -> Tuple[int, int]:  # noqa: D102
    return self.k1.shape[0], self.k2.shape[0]

  @property
  def dtype(self) -> jnp.dtype:  # noqa: D102
    return self.k1.dtype

  def transport_from_potentials(
      self, f: jnp.ndarray, g: jnp.ndarray
  ) -> jnp.ndarray:
    """Not implemented."""
    raise ValueError("Not implemented.")

  def tree_flatten(self):  # noqa: D102
    return [self.k1, self.k2, self._epsilon_init], {}

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    return cls(*children, **aux_data)

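
# Hedged usage sketch for the kernel geometry above (the helper name is
# hypothetical): build an ``LRKGeometry`` from two point clouds via
# ``from_pointcloud`` and check that the factored ``apply_kernel`` agrees
# with multiplying by the materialized ``kernel_matrix``.
def _demo_lrk_apply() -> None:
  import numpy as np

  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (20, 3))
  y = jax.random.normal(rng_y, (15, 3))

  geom = LRKGeometry.from_pointcloud(x, y, kernel="gaussian", rank=64)
  scaling = jnp.ones(geom.shape[0])  # a [num_a,] vector

  # ``axis=0`` applies the transposed kernel: [num_a,] -> [num_b,].
  out = geom.apply_kernel(scaling, axis=0)
  np.testing.assert_allclose(
      out, geom.kernel_matrix.T @ scaling, rtol=1e-4, atol=1e-6
  )
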
def _gaussian_kernel(
    rng: jax.Array,
    x: jnp.ndarray,
    n_features: int,
    eps: float,
    R: jnp.ndarray,
) -> jnp.ndarray:
  _, d = x.shape
  cost_fn = costs.SqEuclidean()

  y = (R ** 2) / (eps * d)
  q = y / (2.0 * mu.lambertw(y))
  sigma = jnp.sqrt(q * eps * 0.25)

  u = jax.random.normal(rng, shape=(n_features, d)) * sigma
  cost = cost_fn.all_pairs(x, u)
  norm_u = cost_fn.norm(u)

  tmp = -2.0 * (cost / eps) + (norm_u / (eps * q))
  phi = (2 * q) ** (d / 4) * jnp.exp(tmp)
  return (1.0 / jnp.sqrt(n_features)) * phi


def _arccos_kernel(
    rng: jax.Array,
    x: jnp.ndarray,
    n_features: int,
    n: int,
    std: float = 1.0,
    kappa: float = 1e-6,
) -> jnp.ndarray:
  n_points, d = x.shape
  c = jnp.sqrt(2) * (std ** (d / 2))

  u = jax.random.normal(rng, shape=(n_features, d)) * std
  tmp = -(1 / 4) * jnp.sum(u ** 2, axis=-1) * (1.0 - (1.0 / (std ** 2)))
  phi = c * (jnp.maximum(0.0, (x @ u.T)) ** n) * jnp.exp(tmp)
  return jnp.c_[(1.0 / jnp.sqrt(n_features)) * phi,
                jnp.full((n_points,), fill_value=kappa)]
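

# A final hedged sketch (hypothetical helper name): the positive features
# produced by ``_gaussian_kernel`` approximate the Gibbs kernel
# ``exp(-||x - y||^2 / eps)`` in expectation, so ``phi @ phi.T`` should be
# close to the dense kernel once enough random features are drawn. This is
# a Monte-Carlo estimate, hence the error is printed rather than asserted.
def _demo_gaussian_features() -> None:
  rng_x, rng_u = jax.random.split(jax.random.PRNGKey(42))
  x = 0.1 * jax.random.normal(rng_x, (10, 2))
  eps = 1.0
  R = jnp.linalg.norm(x, axis=-1).max()

  phi = _gaussian_kernel(rng_u, x, n_features=4096, eps=eps, R=R)
  approx = phi @ phi.T  # low-rank estimate of the kernel matrix

  sq_dist = jnp.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)
  exact = jnp.exp(-sq_dist / eps)
  print("max abs error of the feature approximation:",
        jnp.abs(approx - exact).max())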