Source code for

# Copyright OTT-JAX
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Literal, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from jax.experimental import checkify

from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn
from ott.tools import sinkhorn_divergence as sd

# Public API of this module.
# NOTE(review): the original list was truncated in this copy of the file; the
# names below are the documented public symbols -- confirm against upstream.
__all__ = [
    "ProgOT",
    "ProgOTOutput",
    "get_epsilon_schedule",
    "get_alpha_schedule",
]

# Output of a single OT solve: plain Sinkhorn, or the Sinkhorn divergence when
# the solver is debiased.
Output = Union[sinkhorn.SinkhornOutput, sd.SinkhornDivergenceOutput]

class ProgOTState(NamedTuple):
  """Carry passed between consecutive steps of the :class:`ProgOT` scan."""
  # Source points at the current step (the interpolation produced so far).
  x: jnp.ndarray
  # Pair of potentials used to initialize the next solve, or ``None`` when
  # warm start is disabled.
  init_potentials: Optional[Tuple[jnp.ndarray, jnp.ndarray]]

class ProgOTOutput(NamedTuple):
  """Output of the :class:`ProgOT` solver.

  Args:
    prob: Linear problem.
    alphas: Stepsize schedule of shape ``[num_steps,]``.
    epsilons: Entropy regularizations of shape ``[num_steps,]``.
    outputs: OT solver outputs for every step, a struct of arrays.
    xs: Intermediate interpolations of shape ``[num_steps, n, d]``, if present.
  """
  prob: linear_problem.LinearProblem
  alphas: jnp.ndarray
  epsilons: jnp.ndarray
  outputs: Output
  xs: Optional[jnp.ndarray] = None

  def transport(
      self,
      x: jnp.ndarray,
      num_steps: Optional[int] = None,
      return_intermediate: bool = False,
  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Transport points.

    Args:
      x: Array of shape ``[n, d]`` to transport.
      num_steps: Number of steps. If :obj:`None`, use the full number of steps.
      return_intermediate: Whether to return intermediate values.

    Returns:
      - If ``return_intermediate = True``, return arrays of shape
        ``[num_steps, n, d]`` and ``[num_steps, n, d]`` corresponding to the
        interpolations and push-forwards after each step, respectively.
      - Otherwise, return arrays of shape ``[n, d]`` and ``[n, d]``
        corresponding to the last interpolation and push-forward, respectively.
    """

    def body_fn(
        xy: Tuple[jnp.ndarray, Optional[jnp.ndarray]], it: int
    ) -> Tuple[Tuple[jnp.ndarray, Optional[jnp.ndarray]], Tuple[
        Optional[jnp.ndarray], Optional[jnp.ndarray]]]:
      x, _ = xy
      alpha = self.alphas[it]
      # Push the current points through the map estimated at step ``it``, then
      # move only a fraction ``alpha`` of the way towards the push-forward.
      dp = self.get_output(it).to_dual_potentials()
      t_x = dp.transport(x, forward=True)
      next_x = _interpolate(
          x=x, t_x=t_x, alpha=alpha, cost_fn=self.prob.geom.cost_fn
      )
      if return_intermediate:
        # Per-step values are collected by the scan; the carry only needs `x`.
        return (next_x, None), (next_x, t_x)
      # Only the last values are needed, so they travel in the carry.
      return (next_x, t_x), (None, None)

    if num_steps is None:
      num_steps = self.num_steps
    else:
      assert (
          0 < num_steps <= self.num_steps
      ), f"Maximum number of steps must be in (0, {self.num_steps}], " \
         f"found {num_steps}."

    # The carry must keep the same structure across iterations, hence the
    # placeholder array when the push-forward is returned through the carry.
    state = (x, None) if return_intermediate else (x, jnp.empty_like(x))
    xy, xs_ys = jax.lax.scan(body_fn, state, xs=jnp.arange(num_steps))
    return xs_ys if return_intermediate else xy

  def get_output(self, step: int) -> Output:
    r"""Get the OT solver output at a given step.

    Args:
      step: Iteration step in :math:`[0, \text{num_steps})`.

    Returns:
      The OT solver output at a ``step``.
    """
    # `outputs` is a struct of arrays stacked over steps; slice every leaf.
    return jtu.tree_map(lambda x: x[step], self.outputs)

  @property
  def converged(
      self
  ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]:
    """Convergence at each step.

    - If :attr:`is_debiased`, return an array of shape ``[num_steps, 3]`` with
      values corresponding to the convergence of the ``(x, y)``, ``(x, x)`` and
      ``(y, y)`` problems.
    - Otherwise, return an array of shape ``[num_steps,]``.
    """
    return jnp.stack(self.outputs.converged, axis=-1)

  @property
  def num_iters(self) -> jnp.ndarray:
    """Number of Sinkhorn iterations within each step.

    - If :attr:`is_debiased`, return an array of shape ``[num_steps, 3]`` with
      values corresponding to the number of iterations for the ``(x, y)``,
      ``(x, x)`` and ``(y, y)`` problems.
    - Otherwise, return an array of shape ``[num_steps,]``.
    """
    return jnp.array([
        self.get_output(it).n_iters for it in range(self.num_steps)
    ])

  @property
  def num_steps(self) -> int:
    """Number of :class:`ProgOT` steps."""
    return len(self.alphas)

  @property
  def is_debiased(self) -> bool:
    """Whether the OT solver is debiased."""
    # `outputs` is one struct of stacked arrays (see `get_output`), not a
    # sequence of per-step outputs, so the struct itself must be inspected;
    # indexing it would look at its first field instead.
    return isinstance(self.outputs, sd.SinkhornDivergenceOutput)
@jtu.register_pytree_node_class
class ProgOT:
  """Progressive Entropic Optimal Transport solver :cite:`kassraie:24`.

  Args:
    alphas: Stepsize schedule of shape ``[num_steps,]``.
    epsilons: Epsilon regularization schedule of shape ``[num_steps,]``.
      If :obj:`None`, use the default epsilon at each step.
    epsilon_scales: Scale for the default epsilon of shape ``[num_steps,]``.
      If :obj:`None`, don't scale the epsilons. Note that only one of
      ``epsilons`` and ``epsilon_scales`` can be passed.
    is_debiased: Whether to solve each step with the Sinkhorn divergence
      (``sd.sinkhorn_divergence``) or with
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
  """

  def __init__(
      self,
      alphas: jnp.ndarray,
      *,
      epsilons: Optional[jnp.ndarray] = None,
      epsilon_scales: Optional[jnp.ndarray] = None,
      is_debiased: bool = False,
  ):
    if epsilons is not None and epsilon_scales is not None:
      raise ValueError(
          "Please pass either `epsilons` or `epsilon_scales`, not both."
      )
    if epsilons is not None:
      assert len(alphas) == len(
          epsilons
      ), "Epsilons have different length than alphas."
    if epsilon_scales is not None:
      assert len(alphas) == len(
          epsilon_scales
      ), "Epsilon scales have different length than alphas."

    checkify.check(
        jnp.all((alphas >= 0.0) & (alphas <= 1.0)),
        "Alphas must be a sequence with values between zero and one."
    )

    self.alphas = alphas
    self.epsilons = epsilons
    self.epsilon_scales = epsilon_scales
    self.is_debiased = is_debiased

  def __call__(
      self,
      prob: linear_problem.LinearProblem,
      warm_start: bool = False,
      **kwargs: Any,
  ) -> ProgOTOutput:
    """Run the solver.

    Args:
      prob: Linear problem.
      warm_start: Whether to initialize potentials from the previous step.
        Not supported when :attr:`is_debiased` is set.
      kwargs: Keyword arguments for
        :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` or for the Sinkhorn
        divergence solver, depending on :attr:`is_debiased`.

    Returns:
      The solver output.
    """

    def body_fn(state: ProgOTState,
                it: int) -> Tuple[ProgOTState, Tuple[Output, float]]:
      alpha = self.alphas[it]
      eps = None if self.epsilons is None else self.epsilons[it]
      if self.epsilon_scales is not None:
        # use the default epsilon and scale it
        geom = pointcloud.PointCloud(state.x, y, cost_fn=cost_fn)
        eps = self.epsilon_scales[it] * geom.epsilon

      if self.is_debiased:
        out = _sinkhorn_divergence(
            state.x, y, cost_fn=cost_fn, eps=eps, **kwargs
        )
        eps = out.geoms[0].epsilon
      else:
        out = _sinkhorn(
            state.x,
            y,
            cost_fn=cost_fn,
            eps=eps,
            init=state.init_potentials,
            **kwargs
        )
        eps = out.geom.epsilon

      # Move the points a fraction `alpha` towards their push-forward.
      t_x = out.to_dual_potentials().transport(state.x, forward=True)
      next_x = _interpolate(x=state.x, t_x=t_x, alpha=alpha, cost_fn=cost_fn)
      # When warm-starting, the next solve starts from the current potentials
      # scaled by the remaining fraction ``1 - alpha``.
      next_init = ((1.0 - alpha) * out.f,
                   (1.0 - alpha) * out.g) if warm_start else None

      next_state = ProgOTState(x=next_x, init_potentials=next_init)
      return next_state, (out, eps)

    # Checked once up-front rather than inside the scanned body: the initial
    # potentials are non-`None` exactly when `warm_start` is set.
    assert not (self.is_debiased and warm_start), \
      "Warm start is not implemented for debiased."

    num_steps = len(self.alphas)
    n, m = prob.geom.shape
    x, y, cost_fn = prob.geom.x, prob.geom.y, prob.geom.cost_fn

    if warm_start:
      # Initial potentials: zeros in log-space mode, ones otherwise.
      lse_mode = kwargs.get("lse_mode", True)
      init_potentials = (jnp.zeros(n), jnp.zeros(m)
                        ) if lse_mode else (jnp.ones(n), jnp.ones(m))
    else:
      init_potentials = None

    init_state = ProgOTState(x=x, init_potentials=init_potentials)
    _, (outputs, epsilons) = jax.lax.scan(
        body_fn, init_state, xs=jnp.arange(num_steps)
    )

    return ProgOTOutput(
        prob,
        alphas=self.alphas,
        epsilons=epsilons,
        outputs=outputs,
    )

  def tree_flatten(self):  # noqa: D102
    # The schedules are pytree children; the debiasing flag is static data.
    return (self.alphas, self.epsilons, self.epsilon_scales), {
        "is_debiased": self.is_debiased,
    }

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: dict[str, Any], children: Any
  ) -> "ProgOT":
    alphas, epsilons, epsilon_scales = children
    return cls(
        alphas=alphas,
        epsilons=epsilons,
        epsilon_scales=epsilon_scales,
        **aux_data
    )
def get_epsilon_schedule(
    geom: pointcloud.PointCloud,
    *,
    alphas: jnp.ndarray,
    epsilon_scales: jnp.ndarray,
    y_eval: jnp.ndarray,
    start_epsilon_scale: float = 1.0,
    **kwargs: Any,
) -> jnp.ndarray:
  """Get the epsilon regularization schedule.

  See Algorithm 4 in :cite:`kassraie:24` for more information.

  Args:
    geom: Point cloud geometry.
    alphas: Stepsize schedule of shape ``[num_steps,]``.
    epsilon_scales: Array of shape ``[num_scales,]`` from which to select
      the best scale of the default epsilon in the ``(y, y)`` point cloud.
    y_eval: Array of shape ``[k, d]`` from the target distribution used to
      compute the error.
    start_epsilon_scale: Constant by which to scale the initial epsilon.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.

  Returns:
    The epsilon regularization schedule of shape ``[num_steps,]``.
  """

  def error(epsilon_scale: float) -> float:
    # Error of mapping held-out target points onto themselves when solving the
    # ``(y, y)`` problem with the candidate epsilon.
    epsilon = epsilon_scale * geom_end.epsilon
    out = _sinkhorn(y, y, cost_fn=cost_fn, eps=epsilon, **kwargs)
    dp = out.to_dual_potentials()
    y_hat = dp.transport(y_eval, forward=True)
    return jnp.linalg.norm(y_eval - y_hat)

  y, cost_fn = geom.y, geom.cost_fn
  start_eps = start_epsilon_scale * geom.epsilon

  geom_end = pointcloud.PointCloud(y, y, cost_fn=cost_fn)
  errors = jax.vmap(error)(epsilon_scales)
  end_epsilon = epsilon_scales[jnp.argmin(errors)] * geom_end.epsilon

  # Fraction of the path covered *before* each step: a cumulative product is a
  # prefix operation, so entry ``k`` only depends on ``alphas[:k]``. The
  # schedule therefore never depends on whether the stepsizes end with 1.0, and
  # no data-dependent Python branch on the last stepsize is required.
  mod_alpha = jnp.concatenate([jnp.array([0.0]), alphas])
  tk = 1.0 - jnp.cumprod(1.0 - mod_alpha)
  epsilons = end_epsilon * tk + (1.0 - tk) * start_eps
  # Drop the value after the final step to keep exactly ``num_steps`` entries.
  return epsilons[:-1]
def get_alpha_schedule(
    kind: Literal["lin", "exp", "quad"], *, num_steps: int
) -> jnp.ndarray:
  """Build a stepsize schedule with ``num_steps`` entries.

  Convenience helper returning ``num_steps`` stepsizes in ``[0, 1]`` laid out
  according to ``kind``. See Section 4 in :cite:`kassraie:24` for details.

  Args:
    kind: Which schedule to build:

      - ``'lin'`` - constant-speed schedule.
      - ``'exp'`` - decelerating schedule.
      - ``'quad'`` - accelerating schedule.
    num_steps: Total number of steps.

  Returns:
    The stepsize schedule, array of shape ``[num_steps,]``.

  Raises:
    ValueError: If ``kind`` is not one of the supported schedules.
  """
  if kind == "exp":
    # The same stepsize ``1 / e`` at every step.
    return jnp.full(num_steps, fill_value=1.0 / jnp.e)
  if kind not in ("lin", "quad"):
    raise ValueError(f"Invalid stepsize schedule `{kind}`.")

  # Both remaining schedules are functions of the indices ``2..num_steps+1``.
  steps = jnp.arange(2, num_steps + 2)
  if kind == "lin":
    return 1.0 / (num_steps - steps + 2)

  numerator = 2.0 * steps - 1.0
  denominator = (num_steps + 1) ** 2 - (steps - 1) ** 2
  return numerator / denominator
def _sinkhorn(
    x: jnp.ndarray,
    y: jnp.ndarray,
    cost_fn: costs.TICost,
    eps: Optional[float],
    init: Optional[Tuple[jnp.ndarray, jnp.ndarray]] = None,
    **kwargs: Any,
) -> sinkhorn.SinkhornOutput:
  """Run Sinkhorn on the point clouds ``x`` and ``y``.

  Args:
    x: Source points.
    y: Target points.
    cost_fn: Cost function of the point cloud geometry.
    eps: Epsilon regularization passed to the geometry.
    init: Optional initial potentials forwarded to the solver.
    kwargs: Keyword arguments for the Sinkhorn solver constructor.

  Returns:
    The Sinkhorn output.
  """
  problem = linear_problem.LinearProblem(
      pointcloud.PointCloud(x, y, cost_fn=cost_fn, epsilon=eps)
  )
  return sinkhorn.Sinkhorn(**kwargs)(problem, init=init)


def _sinkhorn_divergence(
    x: jnp.ndarray,
    y: jnp.ndarray,
    cost_fn: costs.TICost,
    eps: Optional[float],
    **kwargs: Any,
) -> sd.SinkhornDivergenceOutput:
  """Compute the Sinkhorn divergence between ``x`` and ``y``.

  Args:
    x: Source points.
    y: Target points.
    cost_fn: Cost function of the point cloud geometry.
    eps: Epsilon regularization.
    kwargs: Keyword arguments forwarded as ``solve_kwargs``.

  Returns:
    The divergence output; the divergence value itself is discarded.
  """
  _divergence, divergence_output = sd.sinkhorn_divergence(
      pointcloud.PointCloud,
      x,
      y,
      cost_fn=cost_fn,
      epsilon=eps,
      share_epsilon=False,
      solve_kwargs=kwargs,
  )
  return divergence_output


def _interpolate(
    x: jnp.ndarray, t_x: jnp.ndarray, alpha: float, cost_fn: costs.TICost
) -> jnp.ndarray:
  """Interpolate between ``x`` and ``t_x`` as a barycenter under ``cost_fn``.

  Args:
    x: Current points.
    t_x: Points to interpolate towards.
    alpha: Weight given to ``t_x``; ``x`` receives ``1 - alpha``.
    cost_fn: Cost function whose ``barycenter`` defines the interpolation.

  Returns:
    The interpolated points.
  """
  endpoints = jnp.stack([x, t_x])
  mix = jnp.array([1.0 - alpha, alpha])
  # `barycenter` returns a pair; only its first element is used here.
  barycenter, _ = cost_fn.barycenter(weights=mix, xs=endpoints)
  return barycenter