Source code for ott.experimental.mmsinkhorn

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np

from ott.geometry import costs, epsilon_scheduler, pointcloud
from ott.math import fixed_point_loop
from ott.math import utils as mu

__all__ = ["MMSinkhornOutput", "MMSinkhorn"]


class MMSinkhornState(NamedTuple):
  potentials: Tuple[jnp.ndarray, ...]
  errors: jnp.ndarray

  def solution_error(
      self,
      cost_t: jnp.ndarray,
      a_s: Tuple[jnp.ndarray, ...],
      epsilon: float,
      norm_error: float = 1.0
  ) -> float:
    coupl_tensor = coupling_tensor(self.potentials, cost_t, epsilon)
    marginals = tensor_marginals(coupl_tensor)
    errors = jnp.array([
        jnp.sum(jnp.abs(a - marginal) ** norm_error) ** (1.0 / norm_error)
        for a, marginal in zip(a_s, marginals)
    ])
    return jnp.sum(errors)

  def set(self, **kwargs: Any) -> "MMSinkhornState":
    """Return a copy of self, with potential overwrites."""
    return self._replace(**kwargs)

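# A minimal sketch of what ``MMSinkhornState.solution_error`` measures; the
# helper below is hypothetical (not part of the module) and uses toy shapes.
# With zero potentials and a zero cost tensor, the coupling is all ones, so
# its marginals deviate from the targets and the error is strictly positive.
def _solution_error_sketch() -> jnp.ndarray:
  cost_t = jnp.zeros((2, 3))  # trivial cost tensor over two marginals
  a_s = (jnp.array([0.5, 0.5]), jnp.array([0.2, 0.3, 0.5]))
  state = MMSinkhornState(
      potentials=(jnp.zeros(2), jnp.zeros(3)), errors=-jnp.ones((1, 1))
  )
  # Sum over marginals of the L1 deviation (``norm_error=1.0``).
  return state.solution_error(cost_t, a_s, epsilon=1.0)
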

class MMSinkhornOutput(NamedTuple):
  r"""Output of the MMSinkhorn solver used on :math:`k` point clouds.

  This class contains both the solution and the problem definition of a
  regularized MM-OT problem involving :math:`k` weighted point clouds of
  varying sizes, along with methods and properties that can use or describe
  the solution.

  Args:
    potentials: Tuple of :math:`k` optimal dual variables, vectors of sizes
      equal to the number of points in each of the :math:`k` point clouds.
    errors: Vector of errors, along iterations. This vector is of size
      ``max_iterations // inner_iterations``, where those were the parameters
      passed on to the :class:`~ott.experimental.mmsinkhorn.MMSinkhorn`
      solver. Follows the conventions used in
      :attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.errors`.
    x_s: Tuple of :math:`k` point clouds, ``x_s[i]`` is a matrix of size
      :math:`n_i \times d`, where :math:`d` is common to all point clouds.
    a_s: Tuple of :math:`k` probability vectors, each of size :math:`n_i`.
    cost_fns: Cost function, or a tuple of :math:`k(k-1)/2` such instances.
    epsilon: Entropic regularization used to solve the multimarginal Sinkhorn
      problem.
    ent_reg_cost: The regularized optimal transport cost: the linear
      contribution (dot product between the optimal tensor and the cost)
      minus entropy times ``epsilon``.
    threshold: Convergence threshold used to control the termination of the
      algorithm.
    converged: Whether the output corresponds to a solution whose error is
      below the convergence threshold.
    inner_iterations: Number of iterations that were run between two
      computations of errors.
  """

  potentials: Tuple[jnp.ndarray, ...]
  errors: jnp.ndarray
  x_s: Optional[Tuple[jnp.ndarray, ...]] = None
  a_s: Optional[Tuple[jnp.ndarray, ...]] = None
  cost_fns: Optional[Union[costs.CostFn, Tuple[costs.CostFn, ...]]] = None
  epsilon: Optional[float] = None
  ent_reg_cost: Optional[jnp.ndarray] = None
  threshold: Optional[jnp.ndarray] = None
  converged: Optional[bool] = None
  inner_iterations: Optional[int] = None

  def set(self, **kwargs: Any) -> "MMSinkhornOutput":
    """Return a copy of self, with potential overwrites."""
    return self._replace(**kwargs)

  @property
  def n_iters(self) -> int:  # noqa: D102
    """Total number of iterations that were needed to terminate."""
    return jnp.sum(self.errors != -1) * self.inner_iterations

  @property
  def cost_t(self) -> jnp.ndarray:
    """Cost tensor."""
    return cost_tensor(self.x_s, self.cost_fns)

  @property
  def tensor(self) -> jnp.ndarray:
    """Transport tensor."""
    return jnp.exp(
        -remove_tensor_sum(self.cost_t, self.potentials) / self.epsilon
    )

  @property
  def marginals(self) -> Tuple[jnp.ndarray, ...]:
    """:math:`k` marginal probability weight vectors."""
    return tensor_marginals(self.tensor)

  def marginal(self, k: int) -> jnp.ndarray:
    """Return the marginal probability weight vector at slice :math:`k`."""
    return tensor_marginal(self.tensor, k)

  @property
  def transport_mass(self) -> float:
    """Sum of the transport tensor."""
    return jnp.sum(self.tensor)

  @property
  def shape(self) -> Tuple[int, ...]:
    """Shape of the transport :attr:`tensor`."""
    return tuple(x.shape[0] for x in self.x_s)

  @property
  def n_marginals(self) -> int:
    """Number of marginals."""
    return len(self.x_s)


def cost_tensor(
    x_s: Tuple[jnp.ndarray, ...],
    cost_fns: Union[costs.CostFn, Tuple[costs.CostFn, ...]]
) -> jnp.ndarray:
  r"""Create a cost tensor from a tuple of :math:`k` :math:`d`-dim point clouds.

  Args:
    x_s: Tuple of :math:`k` point clouds, each described as a
      :math:`n_i \times d` matrix of batched vectors.
    cost_fns: Either a single :class:`~ott.geometry.costs.CostFn` object, or
      a tuple of :math:`k(k-1)/2` of them. The current implementation only
      works for symmetric and definite cost functions (i.e. such that
      :math:`c(x, y) = c(y, x)` and :math:`c(x, x) = 0`).
  """

  def c_fn_pair(i: int, j: int) -> costs.CostFn:
    if isinstance(cost_fns, costs.CostFn):
      return cost_fns
    return cost_fns[i * k - (i * (i + 1)) // 2 + j - i - 1]

  k = len(x_s)
  # TODO(cuturi) padded version
  ns = [x.shape[0] for x in x_s]
  cost_t = jnp.zeros(ns)
  for i in range(k):
    for j in range(i + 1, k):
      cost_m = pointcloud.PointCloud(
          x_s[i], x_s[j], cost_fn=c_fn_pair(i, j)
      ).cost_matrix
      axis = list(range(i)) + list(range(i + 1, j)) + list(range(j + 1, k))
      cost_t += jnp.expand_dims(cost_m, axis=axis)
  return cost_t


def remove_tensor_sum(
    c: jnp.ndarray, u: Tuple[jnp.ndarray, ...]
) -> jnp.ndarray:
  r"""Remove the tensor sum of :math:`k` vectors from a :math:`k`-dim tensor.

  Args:
    c: :math:`n_1 \times \cdots \times n_k` tensor.
    u: Tuple of :math:`k` vectors, each of size :math:`n_i`.

  Returns:
    Tensor :math:`c - u[0] \oplus u[1] \oplus \cdots \oplus u[k-1]`.
  """
  k = len(u)
  for i in range(k):
    c -= jnp.expand_dims(u[i], axis=list(range(i)) + list(range(i + 1, k)))
  return c


def tensor_marginals(coupling: jnp.ndarray) -> Tuple[jnp.ndarray, ...]:
  return tuple(tensor_marginal(coupling, ix) for ix in range(coupling.ndim))


def tensor_marginal(coupling: jnp.ndarray, slice_index: int) -> jnp.ndarray:
  k = coupling.ndim
  axis = list(range(slice_index)) + list(range(slice_index + 1, k))
  return coupling.sum(axis=axis)
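

# An illustrative check (hypothetical helper, not part of the module) of the
# pair ordering used by ``cost_tensor``: for ``k`` point clouds, the flat
# tuple of ``k(k-1)/2`` cost functions enumerates pairs (0, 1), (0, 2), ...,
# (k-2, k-1), matching the index ``i * k - i * (i + 1) // 2 + j - i - 1``.
# It also checks that ``remove_tensor_sum`` subtracts the outer sum of the
# potential vectors entry-wise.
def _helpers_sketch() -> None:
  k = 3
  pairs = [(i, j) for i in range(k) for j in range(i + 1, k)]
  for idx, (i, j) in enumerate(pairs):
    assert idx == i * k - i * (i + 1) // 2 + j - i - 1
  c = jnp.ones((2, 2, 2))
  u = (jnp.array([1.0, 2.0]), jnp.zeros(2), jnp.zeros(2))
  # Entry [a, b, c] of the result equals 1 - u[0][a] - u[1][b] - u[2][c].
  assert remove_tensor_sum(c, u)[1, 0, 0] == -1.0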


@jtu.register_pytree_node_class
class MMSinkhorn:
  r"""Multimarginal Sinkhorn solver, aligning :math:`k` :math:`d`-dim point clouds.

  This solver implements the entropic multimarginal solver presented in
  :cite:`benamou:15` and described in :cite:`piran:24`, Algorithm 1. The
  current implementation largely follows the template of the
  :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver, with a much reduced
  set of hyperparameters controlling the number of iterations and the
  convergence threshold, along with the application of the :cite:`danskin:67`
  theorem to instantiate the OT cost. The iterations are done by default in
  log-space.

  Args:
    threshold: Tolerance used to stop the Sinkhorn iterations. This is
      typically the deviation between a target marginal and the marginal of
      the current primal solution.
    norm_error: Power used to define the p-norm of the error between the
      marginals and their targets.
    inner_iterations: The Sinkhorn error is not recomputed at each iteration
      but every ``inner_iterations`` instead.
    min_iterations: The minimum number of Sinkhorn iterations carried out
      before the error is computed and monitored.
    max_iterations: The maximum number of Sinkhorn iterations. If
      ``max_iterations`` is equal to ``min_iterations``, Sinkhorn iterations
      are run by default using a :func:`~jax.lax.scan` loop rather than a
      custom, unroll-able :func:`~jax.lax.while_loop` that monitors
      convergence. In that case the error is not monitored and the
      ``converged`` flag will return :obj:`False` as a consequence.
    use_danskin: When :obj:`True`, it is assumed the entropy regularized cost
      is evaluated using optimal potentials that are frozen, i.e. whose
      gradients have been stopped. This is useful when carrying out
      first-order differentiation, and is only valid mathematically when the
      algorithm has converged with a low tolerance.
  """

  def __init__(
      self,
      threshold: float = 1e-3,
      norm_error: float = 1.0,
      inner_iterations: int = 10,
      min_iterations: int = 0,
      max_iterations: int = 2000,
      use_danskin: bool = True,
  ):
    self.threshold = threshold
    self.inner_iterations = inner_iterations
    self.min_iterations = min_iterations
    self.max_iterations = max_iterations
    self.norm_error = norm_error
    self.use_danskin = use_danskin

  def __call__(
      self,
      x_s: Tuple[jnp.ndarray, ...],
      a_s: Optional[Tuple[jnp.ndarray, ...]] = None,
      cost_fns: Optional[Union[costs.CostFn, Tuple[costs.CostFn, ...]]] = None,
      epsilon: Optional[float] = None
  ) -> MMSinkhornOutput:
    r"""Solve multimarginal OT for :math:`k` :math:`d`-dim point clouds.

    Takes :math:`k` weighted :math:`d`-dim point clouds and computes their
    multimarginal optimal transport tensor. The :math:`d`-dimensional point
    clouds are stored in ``x_s``, along with :math:`k` probability vectors,
    stored in ``a_s``, as well as a :class:`~ott.geometry.costs.CostFn`
    instance (or :math:`k(k-1)/2` of them, one for each pair of point clouds
    ``x_s[i]`` and ``x_s[j]``, ``i<j``). The solver also uses ``epsilon`` as
    an input; the default rule sets it to one twentieth of the standard
    deviation of all the values stored in the cost tensor resulting from
    these inputs.

    Args:
      x_s: Tuple of :math:`k` point clouds, ``x_s[i]`` is a matrix of size
        :math:`n_i \times d`, where :math:`d` is a dimension common to all
        point clouds.
      a_s: Tuple of :math:`k` probability vectors, each of size :math:`n_i`.
      cost_fns: Instance of :class:`~ott.geometry.costs.CostFn`, or a tuple
        of :math:`k(k-1)/2` such instances. Note that the solver currently
        assumes that these cost functions are symmetric.
        The cost function at index :math:`i(k - \tfrac{i+1}{2}) + j - i - 1`
        will be used to compare point cloud ``x_s[i]`` with point cloud
        ``x_s[j]``.
      epsilon: Entropic regularization used to solve the multimarginal
        Sinkhorn problem.

    Returns:
      Multimarginal Sinkhorn output.
    """
    n_s = [x.shape[0] for x in x_s]
    if cost_fns is None:
      cost_fns = costs.SqEuclidean()
    elif isinstance(cost_fns, tuple):
      assert len(cost_fns) == (len(n_s) * (len(n_s) - 1)) // 2

    # Default to uniform probability weights for each point cloud.
    if a_s is None:
      a_s = [jnp.ones(n) / n for n in n_s]
    else:
      # Case in which the user passes ``None`` weights within the tuple.
      a_s = [(jnp.ones(n) / n if a is None else a) for a, n in zip(a_s, n_s)]
    assert len(n_s) == len(a_s), (len(n_s), len(a_s))
    for n, a in zip(n_s, a_s):
      assert n == a.shape[0], (n, a.shape[0])

    cost_t = cost_tensor(x_s, cost_fns)
    state = self.init_state(n_s)
    if epsilon is None:
      epsilon = epsilon_scheduler.DEFAULT_EPSILON_SCALE * jnp.std(cost_t)
    const = cost_t, a_s, epsilon
    out = run(const, self, state)
    return out.set(x_s=x_s, a_s=a_s, cost_fns=cost_fns, epsilon=epsilon)

  def init_state(self, n_s: Tuple[int, ...]) -> MMSinkhornState:
    """Return the initial state of the loop."""
    errors = -jnp.ones((self.outer_iterations, 1))
    potentials = tuple(jnp.zeros(n) for n in n_s)
    return MMSinkhornState(potentials=potentials, errors=errors)

  def _converged(self, state: MMSinkhornState, iteration: int) -> bool:
    err = state.errors[iteration // self.inner_iterations - 1, 0]
    return jnp.logical_and(iteration > 0, err < self.threshold)

  def _diverged(self, state: MMSinkhornState, iteration: int) -> bool:
    err = state.errors[iteration // self.inner_iterations - 1, 0]
    return jnp.logical_not(jnp.isfinite(err))

  def _continue(self, state: MMSinkhornState, iteration: int) -> bool:
    """Continue while not(converged) and not(diverged)."""
    return jnp.logical_and(
        jnp.logical_not(self._diverged(state, iteration)),
        jnp.logical_not(self._converged(state, iteration))
    )

  @property
  def outer_iterations(self) -> int:
    """Upper bound on the number of times ``inner_iterations`` are carried out.

    This integer can be used to set constant array sizes to track the
    algorithm's progress, notably errors.
    """
    return np.ceil(self.max_iterations / self.inner_iterations).astype(int)

  def tree_flatten(self):  # noqa: D102
    aux = vars(self).copy()
    aux.pop("threshold")
    return [self.threshold], aux

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    return cls(**aux_data, threshold=children[0])
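

# A minimal usage sketch, assuming three small random point clouds; the
# helper and its shapes are illustrative, not part of the module.
def _mmsinkhorn_usage_sketch() -> MMSinkhornOutput:
  rngs = jax.random.split(jax.random.PRNGKey(0), 3)
  # Three point clouds of sizes 5, 6 and 7, sharing the dimension d = 2.
  x_s = tuple(jax.random.normal(r, (n, 2)) for r, n in zip(rngs, (5, 6, 7)))
  solver = MMSinkhorn(threshold=1e-3, max_iterations=1000)
  # Uniform weights and the squared Euclidean cost are used by default.
  out = solver(x_s)
  # ``out.tensor`` has shape (5, 6, 7); after convergence, ``out.marginals``
  # approximately match the input weights ``out.a_s``.
  return out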


def run(
    const: Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...], float],
    solver: MMSinkhorn, state: MMSinkhornState
) -> MMSinkhornOutput:

  def cond_fn(
      iteration: int,
      const: Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...], float],
      state: MMSinkhornState
  ) -> bool:
    del const
    return solver._continue(state, iteration)

  def body_fn(
      iteration: int,
      const: Tuple[jnp.ndarray, Tuple[jnp.ndarray, ...], float],
      state: MMSinkhornState,
      compute_error: bool
  ) -> MMSinkhornState:
    cost_t, a_s, epsilon = const
    k = len(a_s)

    def one_slice(
        potentials: Tuple[jnp.ndarray, ...], l: int, a: jnp.ndarray
    ) -> Tuple[jnp.ndarray, ...]:
      pot = potentials[l]
      axis = list(range(l)) + list(range(l + 1, k))
      app_lse = mu.softmin(
          remove_tensor_sum(cost_t, potentials), epsilon, axis=axis
      )
      pot += epsilon * jnp.log(a) + jnp.where(
          jnp.isfinite(app_lse), app_lse, 0
      )
      return potentials[:l] + (pot,) + potentials[l + 1:]

    # Coordinate ascent: update each of the k potentials in turn.
    potentials = state.potentials
    for l in range(k):
      potentials = one_slice(potentials, l, a_s[l])
    state = state.set(potentials=potentials)

    err = jax.lax.cond(
        jnp.logical_or(
            iteration == solver.max_iterations - 1,
            jnp.logical_and(compute_error, iteration >= solver.min_iterations)
        ),
        lambda state, c, a, e: state.solution_error(c, a, e, solver.norm_error),
        lambda *_: jnp.inf, state, cost_t, a_s, epsilon
    )
    errors = state.errors.at[iteration // solver.inner_iterations, :].set(err)
    return state.set(errors=errors)

  fix_point = fixed_point_loop.fixpoint_iter_backprop
  state = fix_point(
      cond_fn, body_fn, solver.min_iterations, solver.max_iterations,
      solver.inner_iterations, const, state
  )

  converged = jnp.logical_and(
      jnp.logical_not(jnp.any(jnp.isnan(state.errors))),
      state.errors[-1, 0] < solver.threshold
  )
  out = MMSinkhornOutput(
      potentials=state.potentials,
      errors=state.errors,
      threshold=solver.threshold,
      converged=converged,
      inner_iterations=solver.inner_iterations
  )

  # Compute the entropy-regularized cost, optionally freezing the potentials
  # (Danskin theorem) so that gradients do not flow through them.
  if solver.use_danskin:
    potentials = [jax.lax.stop_gradient(pot) for pot in out.potentials]
  else:
    potentials = out.potentials

  cost_t, a_s, epsilon = const
  ent_reg_cost = 0.0
  for potential, a in zip(potentials, a_s):
    pot = jnp.where(jnp.isfinite(potential), potential, 0)
    ent_reg_cost += jnp.sum(pot * a)
  ent_reg_cost += epsilon * (
      1 - jnp.sum(coupling_tensor(potentials, cost_t, epsilon))
  )
  return out.set(ent_reg_cost=ent_reg_cost)


def coupling_tensor(
    potentials: Tuple[jnp.ndarray], cost_t: jnp.ndarray, epsilon: float
) -> jnp.ndarray:
  return jnp.exp(-remove_tensor_sum(cost_t, potentials) / epsilon)
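

# A sketch of first-order differentiation through the solver, relying on the
# Danskin theorem mentioned in ``MMSinkhorn``: with ``use_danskin=True`` the
# potentials are frozen, so gradients of ``ent_reg_cost`` with respect to the
# point clouds flow only through the cost tensor. The helper, shapes and seed
# are illustrative, not part of the module.
def _danskin_gradient_sketch() -> Tuple[jnp.ndarray, ...]:

  def loss(x_s: Tuple[jnp.ndarray, ...]) -> jnp.ndarray:
    return MMSinkhorn(use_danskin=True)(x_s).ent_reg_cost

  rngs = jax.random.split(jax.random.PRNGKey(1), 2)
  x_s = tuple(jax.random.normal(r, (4, 3)) for r in rngs)
  # One gradient array per point cloud, each of shape (4, 3).
  return jax.grad(loss)(x_s)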