Source code for ott.initializers.linear.initializers

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Any, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.problems.linear import linear_problem

__all__ = [
    "DefaultInitializer", "GaussianInitializer", "SortingInitializer",
    "SubsampleInitializer"
]


@jax.tree_util.register_pytree_node_class
class SinkhornInitializer(abc.ABC):
  """Base class for Sinkhorn initializers."""

  @abc.abstractmethod
  def init_dual_a(
      self,
      ot_prob: linear_problem.LinearProblem,
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
  ) -> jnp.ndarray:
    """Initialize Sinkhorn potential/scaling f_u.

    Args:
      ot_prob: Linear OT problem.
      lse_mode: Return potential if ``True``, scaling if ``False``.
      rng: Random number generator for stochastic initializers.

    Returns:
      potential/scaling, array of size ``[n,]``.
    """

  @abc.abstractmethod
  def init_dual_b(
      self,
      ot_prob: linear_problem.LinearProblem,
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
  ) -> jnp.ndarray:
    """Initialize Sinkhorn potential/scaling g_v.

    Args:
      ot_prob: Linear OT problem.
      lse_mode: Return potential if ``True``, scaling if ``False``.
      rng: Random number generator for stochastic initializers.

    Returns:
      potential/scaling, array of size ``[m,]``.
    """

  def __call__(
      self,
      ot_prob: linear_problem.LinearProblem,
      a: Optional[jnp.ndarray],
      b: Optional[jnp.ndarray],
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
  ) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Initialize Sinkhorn potentials/scalings f_u and g_v.

    Args:
      ot_prob: Linear OT problem.
      a: Initial potential/scaling f_u.
        If ``None``, it will be initialized using :meth:`init_dual_a`.
      b: Initial potential/scaling g_v.
        If ``None``, it will be initialized using :meth:`init_dual_b`.
      lse_mode: Return potentials if ``True``, scalings if ``False``.
      rng: Random number generator for stochastic initializers.

    Returns:
      The initial potentials/scalings.
    """
    n, m = ot_prob.geom.shape
    rng_x, rng_y = jax.random.split(rng, 2)
    if a is None:
      a = self.init_dual_a(ot_prob, lse_mode=lse_mode, rng=rng_x)
    if b is None:
      b = self.init_dual_b(ot_prob, lse_mode=lse_mode, rng=rng_y)

    assert a.shape == (n,), \
        f"Expected `f_u` to have shape `{(n,)}`, found `{a.shape}`."
    assert b.shape == (m,), \
        f"Expected `g_v` to have shape `{(m,)}`, found `{b.shape}`."

    # cancel dual variables for zero weights
    a = jnp.where(ot_prob.a > 0., a, -jnp.inf if lse_mode else 0.)
    b = jnp.where(ot_prob.b > 0., b, -jnp.inf if lse_mode else 0.)

    return a, b

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return [], {}

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "SinkhornInitializer":
    return cls(*children, **aux_data)



@jax.tree_util.register_pytree_node_class
class DefaultInitializer(SinkhornInitializer):
  """Default initialization of Sinkhorn dual potentials/primal scalings."""

  def init_dual_a(  # noqa: D102
      self,
      ot_prob: linear_problem.LinearProblem,
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
  ) -> jnp.ndarray:
    del rng
    return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a)

  def init_dual_b(  # noqa: D102
      self,
      ot_prob: linear_problem.LinearProblem,
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
  ) -> jnp.ndarray:
    del rng
    return jnp.zeros_like(ot_prob.b) if lse_mode else jnp.ones_like(ot_prob.b)
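

# A minimal usage sketch (assumed small random inputs, not library code):
# the default initializer returns zero potentials in ``lse_mode`` and unit
# scalings otherwise; with ``a=None`` the problem uses uniform weights.
def _example_default_initializer() -> None:
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (5, 2))
  y = jax.random.normal(rng_y, (6, 2))
  prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
  init = DefaultInitializer()
  assert bool(jnp.all(init.init_dual_a(prob, lse_mode=True) == 0.0))
  assert bool(jnp.all(init.init_dual_a(prob, lse_mode=False) == 1.0))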


@jax.tree_util.register_pytree_node_class
class GaussianInitializer(DefaultInitializer):
  """Gaussian initializer :cite:`thornton:22`.

  Compute Gaussian approximations of each
  :class:`~ott.geometry.pointcloud.PointCloud`, then compute the closed-form
  Kantorovich potential between the Gaussian approximations using Brenier's
  theorem (adapting the convex/Brenier potential to a Kantorovich potential).
  Use this Gaussian potential to initialize the Sinkhorn potentials/scalings.
  """

  def init_dual_a(  # noqa: D102
      self,
      ot_prob: linear_problem.LinearProblem,
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0)
  ) -> jnp.ndarray:
    # import Gaussian here due to circular imports
    from ott.tools.gaussian_mixture import gaussian

    del rng

    assert isinstance(
        ot_prob.geom, pointcloud.PointCloud
    ), "Gaussian initializer valid only for pointcloud geoms."

    x, y = ot_prob.geom.x, ot_prob.geom.y
    a, b = ot_prob.a, ot_prob.b

    gaussian_a = gaussian.Gaussian.from_samples(x, weights=a)
    gaussian_b = gaussian.Gaussian.from_samples(y, weights=b)
    # Brenier potential for cost ||x-y||^2 / 2; multiply by 2 for ||x-y||^2.
    f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x)
    f_potential = f_potential - jnp.mean(f_potential)
    return f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
        f_potential
    )
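

# A usage sketch under assumed inputs (not library code): the Gaussian
# initializer fits a Gaussian to each point cloud and evaluates the
# closed-form Brenier potential at the source points; the result is
# mean-centered, mirroring the subtraction in ``init_dual_a`` above.
def _example_gaussian_initializer() -> None:
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (20, 3))
  y = jax.random.normal(rng_y, (30, 3)) + 1.0  # shifted target cloud
  prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
  f_init = GaussianInitializer().init_dual_a(prob, lse_mode=True)
  assert f_init.shape == (20,)
  assert bool(jnp.allclose(jnp.mean(f_init), 0.0, atol=1e-4))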


@jax.tree_util.register_pytree_node_class
class SortingInitializer(DefaultInitializer):
  """Sorting initializer :cite:`thornton:22`.

  Solve the non-regularized OT problem via sorting, then compute the
  potential through an iterated minimum on the C-transform, and use this
  potential to initialize the regularized potential.

  Args:
    vectorized_update: Whether to use the vectorized inner-loop update.
    tolerance: DualSort convergence threshold.
    max_iter: Max DualSort steps.
  """

  def __init__(
      self,
      vectorized_update: bool = True,
      tolerance: float = 1e-2,
      max_iter: int = 100
  ):
    super().__init__()
    self.tolerance = tolerance
    self.max_iter = max_iter
    self.vectorized_update = vectorized_update

  def _init_sorting_dual(
      self, modified_cost: jnp.ndarray, init_f: jnp.ndarray
  ) -> jnp.ndarray:
    """Run the DualSort algorithm.

    Args:
      modified_cost: cost matrix minus its diagonal, subtracted column-wise.
      init_f: starting potential f, array of size ``[n,]``, which DualSort
        iteratively refines into the returned initial potential.

    Returns:
      potential f, array of size ``[n,]``.
    """

    def body_fn(
        state: Tuple[jnp.ndarray, float, int]
    ) -> Tuple[jnp.ndarray, float, int]:
      prev_f, _, it = state
      new_f = fn(prev_f, modified_cost)
      diff = jnp.sum((new_f - prev_f) ** 2)
      it += 1
      return new_f, diff, it

    def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool:
      _, diff, it = state
      return jnp.logical_and(diff > self.tolerance, it < self.max_iter)

    fn = _vectorized_update if self.vectorized_update else _coordinate_update
    state = (init_f, jnp.inf, 0)  # (potential, error, iteration)
    f_potential, _, _ = jax.lax.while_loop(
        cond_fun=cond_fn, body_fun=body_fn, init_val=state
    )

    return f_potential

  def init_dual_a(
      self,
      ot_prob: linear_problem.LinearProblem,
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
      init_f: Optional[jnp.ndarray] = None,
  ) -> jnp.ndarray:
    """Apply the DualSort algorithm.

    Args:
      ot_prob: OT problem between discrete distributions.
      lse_mode: Return potential if ``True``, scaling if ``False``.
      rng: Random number generator for stochastic initializers, unused.
      init_f: starting potential f, array of size ``[n,]``, which DualSort
        iteratively refines into the returned initial potential.

    Returns:
      potential/scaling f_u, array of size ``[n,]``.
    """
    del rng
    assert not ot_prob.geom.is_online, \
        "Sorting initializer does not work for online geometry."
    # Checking that x and y are sorted would require a point cloud geometry
    # and could slow down the initializer.
    cost_matrix = ot_prob.geom.cost_matrix
    assert cost_matrix.shape[0] == cost_matrix.shape[1], \
        "Requires square cost matrix."

    modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :]

    n = cost_matrix.shape[0]
    init_f = jnp.zeros(n) if init_f is None else init_f

    f_potential = self._init_sorting_dual(modified_cost, init_f)
    f_potential = f_potential - jnp.mean(f_potential)

    return f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
        f_potential
    )

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return ([], {
        "tolerance": self.tolerance,
        "max_iter": self.max_iter,
        "vectorized_update": self.vectorized_update
    })
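

# A usage sketch with assumed inputs (not library code): DualSort requires a
# square cost matrix, hence the same number of source and target points, and
# a geometry that materializes its cost matrix (no online point clouds).
def _example_sorting_initializer() -> None:
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (25, 2))
  y = jax.random.normal(rng_y, (25, 2))
  prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
  init = SortingInitializer(vectorized_update=True, max_iter=50)
  f_init = init.init_dual_a(prob, lse_mode=True)
  assert f_init.shape == (25,)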


@jax.tree_util.register_pytree_node_class
class SubsampleInitializer(DefaultInitializer):
  """Subsample initializer :cite:`thornton:22`.

  Subsample each :class:`~ott.geometry.pointcloud.PointCloud`, compute the
  :class:`Sinkhorn potential <ott.problems.linear.potentials.DualPotentials>`
  of the subsampled problem, and use this potential to initialize the
  Sinkhorn potentials/scalings of the original problem.

  Args:
    subsample_n_x: number of points to subsample from the first measure in
      the :class:`~ott.geometry.pointcloud.PointCloud`.
    subsample_n_y: number of points to subsample from the second measure in
      the :class:`~ott.geometry.pointcloud.PointCloud`.
      If ``None``, use ``subsample_n_x``.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
  """

  def __init__(
      self,
      subsample_n_x: int,
      subsample_n_y: Optional[int] = None,
      **kwargs: Any,
  ):
    super().__init__()
    self.subsample_n_x = subsample_n_x
    self.subsample_n_y = subsample_n_y or subsample_n_x
    self.sinkhorn_kwargs = kwargs

  def init_dual_a(  # noqa: D102
      self,
      ot_prob: linear_problem.LinearProblem,
      lse_mode: bool,
      rng: jax.random.PRNGKeyArray = jax.random.PRNGKey(0),
  ) -> jnp.ndarray:
    # import Sinkhorn here due to circular imports
    from ott.solvers.linear import sinkhorn

    assert isinstance(
        ot_prob.geom, pointcloud.PointCloud
    ), "Subsample initializer valid only for pointcloud geom."

    x, y = ot_prob.geom.x, ot_prob.geom.y
    a, b = ot_prob.a, ot_prob.b

    # subsample according to the marginal weights
    rng_x, rng_y = jax.random.split(rng, 2)
    sub_x = jax.random.choice(
        key=rng_x, a=x, shape=(self.subsample_n_x,), replace=True, p=a, axis=0
    )
    sub_y = jax.random.choice(
        key=rng_y, a=y, shape=(self.subsample_n_y,), replace=True, p=b, axis=0
    )

    # create the subsampled point cloud geometry
    sub_geom = pointcloud.PointCloud(
        sub_x,
        sub_y,
        epsilon=ot_prob.geom.epsilon,
        scale_cost=ot_prob.geom._scale_cost,
        cost_fn=ot_prob.geom.cost_fn,
    )

    # run Sinkhorn on the subsampled problem
    subsample_sink_out = sinkhorn.solve(sub_geom, **self.sinkhorn_kwargs)

    # evaluate the resulting dual potential at the original points
    dual_potentials = subsample_sink_out.to_dual_potentials()
    f_potential = jax.vmap(dual_potentials.f)(x)

    return f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
        f_potential
    )

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return ([], {
        "subsample_n_x": self.subsample_n_x,
        "subsample_n_y": self.subsample_n_y,
        **self.sinkhorn_kwargs
    })
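

# A usage sketch with assumed sizes (not library code): solve a cheap
# subsampled problem, then lift its dual potential to all original points.
# ``max_iterations`` is a standard Sinkhorn keyword argument forwarded
# through ``kwargs``.
def _example_subsample_initializer() -> None:
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (200, 2))
  y = jax.random.normal(rng_y, (300, 2))
  prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
  init = SubsampleInitializer(subsample_n_x=50, max_iterations=100)
  f_init = init.init_dual_a(prob, lse_mode=True, rng=jax.random.PRNGKey(42))
  assert f_init.shape == (200,)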


def _vectorized_update(
    f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
  """Inner-loop DualSort update.

  Args:
    f: potential f, array of size ``[n,]``.
    modified_cost: cost matrix minus its diagonal, subtracted column-wise.

  Returns:
    updated potential vector, f.
  """
  return jnp.min(modified_cost + f[None, :], axis=1)


def _coordinate_update(
    f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
  """Coordinate-wise update within the inner loop.

  Args:
    f: potential f, array of size ``[n,]``.
    modified_cost: cost matrix minus its diagonal, subtracted column-wise.

  Returns:
    updated potential vector, f.
  """

  def body_fn(i: int, f: jnp.ndarray) -> jnp.ndarray:
    new_f = jnp.min(modified_cost[i, :] + f)
    return f.at[i].set(new_f)

  return jax.lax.fori_loop(0, len(f), body_fn, f)
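

# A sanity sketch (not part of the library): both inner-loop updates take the
# row-wise minimum of ``modified_cost + f``; the vectorized variant is a
# Jacobi-style step (all rows see the old ``f``), the coordinate-wise variant
# is Gauss-Seidel-style (later rows see already-updated entries), so single
# steps may differ even though both drive the same fixed-point iteration.
def _example_compare_updates() -> None:
  cost = jax.random.uniform(jax.random.PRNGKey(0), (10, 10))
  modified_cost = cost - jnp.diag(cost)[None, :]
  f = jnp.zeros(10)
  f_vec = _vectorized_update(f, modified_cost)
  f_coord = _coordinate_update(f, modified_cost)
  assert f_vec.shape == f_coord.shape == (10,)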