Source code for ott.geometry.low_rank

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Literal, Optional, Tuple, Union

import jax
import jax.numpy as jnp

from ott.geometry import geometry

__all__ = ["LRCGeometry"]


[docs]@jax.tree_util.register_pytree_node_class class LRCGeometry(geometry.Geometry): """Geometry whose cost is defined by product of two low-rank matrices. Implements geometries that are defined as low rank products, i.e. for which there exists two matrices :math:`A` and :math:`B` of :math:`r` columns such that the cost of the geometry equals :math:`AB^T`. Apart from being faster to apply to a vector, these geometries are characterized by the fact that adding two such geometries should be carried out by concatenating factors, i.e. if :math:`C = AB^T` and :math:`D = EF^T` then :math:`C + D = [A,E][B,F]^T` Args: cost_1: jnp.ndarray<float>[num_a, r] cost_2: jnp.ndarray<float>[num_b, r] bias: constant added to entire cost matrix. scale: Value used to rescale the factors of the low-rank geometry. scale_cost: option to rescale the cost matrix. Implemented scalings are 'max_bound', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that ``cost_matrix /= scale_cost``. If `True`, use 'mean'. batch_size: optional size of the batch to compute online (without instantiating the matrix) the scale factor ``scale_cost`` of the :attr:`cost_matrix` when ``scale_cost = 'max_cost'``. If `None`, the batch size is set to `1024` or to the largest number of samples between :attr:`cost_1` and :attr:`cost_2` if smaller than `1024`. kwargs: keyword arguments for :class:`~ott.geometry.geometry.Geometry`. """ def __init__( self, cost_1: jnp.ndarray, cost_2: jnp.ndarray, bias: float = 0.0, scale_factor: float = 1.0, scale_cost: Union[bool, int, float, Literal["mean", "max_bound", "max_cost"]] = 1.0, batch_size: Optional[int] = None, **kwargs: Any, ): super().__init__(**kwargs) self._cost_1 = cost_1 self._cost_2 = cost_2 self._bias = bias self._scale_factor = scale_factor self._scale_cost = "mean" if scale_cost is True else scale_cost self.batch_size = batch_size @property def cost_1(self) -> jnp.ndarray: """First factor of the :attr:`cost_matrix`.""" scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost) return scale_factor * self._cost_1 @property def cost_2(self) -> jnp.ndarray: """Second factor of the :attr:`cost_matrix`.""" scale_factor = jnp.sqrt(self._scale_factor * self.inv_scale_cost) return scale_factor * self._cost_2 @property def bias(self) -> float: """Constant offset added to the entire :attr:`cost_matrix`.""" return self._bias * self.inv_scale_cost @property def cost_rank(self) -> int: # noqa: D102 return self._cost_1.shape[1] @property def cost_matrix(self) -> jnp.ndarray: """Materialize the cost matrix.""" return jnp.matmul(self.cost_1, self.cost_2.T) + self.bias @property def shape(self) -> Tuple[int, int]: # noqa: D102 return self._cost_1.shape[0], self._cost_2.shape[0] @property def is_symmetric(self) -> bool: # noqa: D102 return ( self._cost_1.shape[0] == self._cost_2.shape[0] and jnp.all(self._cost_1 == self._cost_2) ) @property def inv_scale_cost(self) -> float: # noqa: D102 if isinstance(self._scale_cost, (int, float, jax.Array)): return 1.0 / self._scale_cost self = self._masked_geom() if self._scale_cost == "max_bound": x_norm = self._cost_1[:, 0].max() y_norm = self._cost_2[:, 1].max() max_bound = x_norm + y_norm + 2 * jnp.sqrt(x_norm * y_norm) return 1.0 / (max_bound + self._bias) if self._scale_cost == "mean": factor1 = jnp.dot(self._n_normed_ones, self._cost_1) factor2 = jnp.dot(self._cost_2.T, self._m_normed_ones) mean = jnp.dot(factor1, factor2) + self._bias return 1.0 / mean if self._scale_cost == "max_cost": return 1.0 / self.compute_max_cost() raise ValueError(f"Scaling {self._scale_cost} not implemented.")
[docs] def apply_square_cost(self, arr: jnp.ndarray, axis: int = 0) -> jnp.ndarray: """Apply elementwise-square of cost matrix to array (vector or matrix).""" (n, m), r = self.shape, self.cost_rank # When applying square of a LRCGeometry, one can either elementwise square # the cost matrix, or instantiate an augmented (rank^2) LRCGeometry # and apply it. First is O(nm), the other is O((n+m)r^2). if n * m < (n + m) * r ** 2: # better use regular apply return super().apply_square_cost(arr, axis) new_cost_1 = self.cost_1[:, :, None] * self.cost_1[:, None, :] new_cost_2 = self.cost_2[:, :, None] * self.cost_2[:, None, :] return LRCGeometry( cost_1=new_cost_1.reshape((n, r ** 2)), cost_2=new_cost_2.reshape((m, r ** 2)) ).apply_cost(arr, axis)
def _apply_cost_to_vec( self, vec: jnp.ndarray, axis: int = 0, fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None, is_linear: bool = False, ) -> jnp.ndarray: """Apply [num_a, num_b] fn(cost) (or transpose) to vector. Args: vec: jnp.ndarray [num_a,] ([num_b,] if axis=1) vector axis: axis on which the reduction is done. fn: function optionally applied to cost matrix element-wise, before the doc product is_linear: Whether ``fn`` is a linear function to enable efficient implementation. See :func:`ott.geometry.geometry.is_linear` for a heuristic to help determine if a function is linear. Returns: A jnp.ndarray corresponding to cost x vector """ def linear_apply( vec: jnp.ndarray, axis: int, fn: Callable[[jnp.ndarray], jnp.ndarray] ) -> jnp.ndarray: c1 = self.cost_1 if axis == 1 else self.cost_2 c2 = self.cost_2 if axis == 1 else self.cost_1 c2 = fn(c2) if fn is not None else c2 bias = fn(self.bias) if fn is not None else self.bias out = jnp.dot(c1, jnp.dot(c2.T, vec)) return out + bias * jnp.sum(vec) * jnp.ones_like(out) if fn is None or is_linear: return linear_apply(vec, axis, fn=fn) return super()._apply_cost_to_vec(vec, axis, fn=fn)
[docs] def compute_max_cost(self) -> float: """Compute the maximum of the :attr:`cost_matrix`. Three cases are taken into account: - If the number of samples of ``cost_1`` and ``cost_2`` are both smaller than 1024 and if ``batch_size`` is `None`, the ``cost_matrix`` is computed to obtain its maximum entry. - If one of the number of samples of ``cost_1`` or ``cost_2`` is larger than 1024 and if ``batch_size`` is `None`, then the maximum of the cost matrix is calculated by batch. The batches are created on the longest axis of the cost matrix and their size is fixed to 1024. - If ``batch_size`` is provided as a float, then the maximum of the cost matrix is calculated by batch. The batches are created on the longest axis of the cost matrix and their size if fixed by ``batch_size``. Returns: Maximum of the cost matrix. """ batch_for_y = self.shape[1] > self.shape[0] n = self.shape[1] if batch_for_y else self.shape[0] p = self._cost_2.shape[1] if batch_for_y else self._cost_1.shape[1] carry = ((self._cost_1, self._cost_2) if batch_for_y else (self._cost_2, self._cost_1)) if self.batch_size: batch_size = min(self.batch_size, n) else: batch_size = min(1024, max(self.shape[0], self.shape[1])) n_batch = n // batch_size def body(carry, slice_idx): cost1, cost2 = carry cost2_slice = jax.lax.dynamic_slice( cost2, (slice_idx * batch_size, 0), (batch_size, p) ) out_slice = jnp.max(jnp.dot(cost2_slice, cost1.T)) return carry, out_slice def finalize(carry): cost1, cost2 = carry return jnp.dot(cost2[n_batch * batch_size:], cost1.T) _, out = jax.lax.scan(body, carry, jnp.arange(n_batch)) last_slice = finalize(carry) max_value = jnp.max(jnp.concatenate((out, last_slice.reshape(-1)))) return max_value + self._bias
[docs] def to_LRCGeometry( self, rank: int = 0, tol: float = 1e-2, rng: Optional[jax.Array] = None, scale: float = 1.0, ) -> "LRCGeometry": """Return self.""" del rank, tol, rng, scale return self
@property def can_LRC(self): # noqa: D102 return True
[docs] def subset( # noqa: D102 self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], **kwargs: Any ) -> "LRCGeometry": def subset_fn( arr: Optional[jnp.ndarray], ixs: Optional[jnp.ndarray], ) -> jnp.ndarray: return arr if arr is None or ixs is None else arr[jnp.atleast_1d(ixs)] return self._mask_subset_helper( src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs )
[docs] def mask( # noqa: D102 self, src_mask: Optional[jnp.ndarray], tgt_mask: Optional[jnp.ndarray], mask_value: float = 0., ) -> "LRCGeometry": def mask_fn( arr: Optional[jnp.ndarray], mask: Optional[jnp.ndarray], ) -> Optional[jnp.ndarray]: if arr is None or mask is None: return arr return jnp.where(mask[:, None], arr, mask_value) src_mask = self._normalize_mask(src_mask, self.shape[0]) tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) return self._mask_subset_helper( src_mask, tgt_mask, fn=mask_fn, propagate_mask=False )
def _mask_subset_helper( self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray], *, fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]], Optional[jnp.ndarray]], propagate_mask: bool, **kwargs: Any, ) -> "LRCGeometry": (c1, c2, src_mask, tgt_mask, *children), aux_data = self.tree_flatten() c1 = fn(c1, src_ixs) c2 = fn(c2, tgt_ixs) if propagate_mask: src_mask = self._normalize_mask(src_mask, self.shape[0]) tgt_mask = self._normalize_mask(tgt_mask, self.shape[1]) src_mask = fn(src_mask, src_ixs) tgt_mask = fn(tgt_mask, tgt_ixs) aux_data = {**aux_data, **kwargs} return type(self).tree_unflatten( aux_data, [c1, c2, src_mask, tgt_mask] + children ) def __add__(self, other: "LRCGeometry") -> "LRCGeometry": if not isinstance(other, LRCGeometry): return NotImplemented return LRCGeometry( cost_1=jnp.concatenate((self.cost_1, other.cost_1), axis=1), cost_2=jnp.concatenate((self.cost_2, other.cost_2), axis=1), bias=self._bias + other._bias, # already included in `cost_{1,2}` scale_factor=1.0, scale_cost=1.0, ) @property def dtype(self) -> jnp.dtype: # noqa: D102 return self._cost_1.dtype def tree_flatten(self): # noqa: D102 return ( self._cost_1, self._cost_2, self._src_mask, self._tgt_mask, self._epsilon_init, self._bias, self._scale_factor, ), { "scale_cost": self._scale_cost, "batch_size": self.batch_size } @classmethod def tree_unflatten(cls, aux_data, children): # noqa: D102 c1, c2, src_mask, tgt_mask, epsilon, bias, scale_factor = children return cls( c1, c2, bias=bias, scale_factor=scale_factor, epsilon=epsilon, src_mask=src_mask, tgt_mask=tgt_mask, **aux_data )