Source code for ott.initializers.linear.initializers_lr

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import functools
from typing import (
    TYPE_CHECKING,
    Any,
    Dict,
    Literal,
    Mapping,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import jax
import jax.numpy as jnp
import numpy as np

from ott import utils
from ott.geometry import geometry, low_rank, pointcloud
from ott.math import fixed_point_loop
from ott.math import utils as mu

if TYPE_CHECKING:
  from ott.problems.linear import linear_problem
  from ott.problems.quadratic import quadratic_problem
  from ott.solvers.linear import sinkhorn, sinkhorn_lr
  from ott.solvers.quadratic import gromov_wasserstein_lr

Problem_t = Union["linear_problem.LinearProblem",
                  "quadratic_problem.QuadraticProblem"]

__all__ = [
    "RandomInitializer", "Rank2Initializer", "KMeansInitializer",
    "GeneralizedKMeansInitializer"
]


@jax.tree_util.register_pytree_node_class
class LRInitializer(abc.ABC):
  """Base class for low-rank initializers.

  Args:
    rank: Rank of the factorization.
    kwargs: Additional keyword arguments.
  """

  def __init__(self, rank: int, **kwargs: Any):
    self._rank = rank
    self._kwargs = kwargs

  @abc.abstractmethod
  def init_q(
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Initialize the low-rank factor :math:`Q`.

    Args:
      ot_prob: OT problem.
      rng: Random key for seeding.
      init_g: Initial value for :math:`g` factor.
      kwargs: Additional keyword arguments.

    Returns:
      Array of shape ``[n, rank]``.
    """

  @abc.abstractmethod
  def init_r(
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Initialize the low-rank factor :math:`R`.

    Args:
      ot_prob: OT problem.
      rng: Random key for seeding.
      init_g: Initial value for :math:`g` factor.
      kwargs: Additional keyword arguments.

    Returns:
      Array of shape ``[m, rank]``.
    """

  @abc.abstractmethod
  def init_g(
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      **kwargs: Any,
  ) -> jnp.ndarray:
    """Initialize the low-rank factor :math:`g`.

    Args:
      ot_prob: OT problem.
      rng: Random key for seeding.
      kwargs: Additional keyword arguments.

    Returns:
      Array of shape ``[rank,]``.
    """

  @classmethod
  def from_solver(
      cls,
      solver: Union["sinkhorn_lr.LRSinkhorn",
                    "gromov_wasserstein_lr.LRGromovWasserstein"],
      *,
      kind: Literal["random", "rank2", "k-means", "generalized-k-means"],
      **kwargs: Any,
  ) -> "LRInitializer":
    """Create a low-rank initializer from a linear or quadratic solver.

    Args:
      solver: Low-rank linear or quadratic solver.
      kind: Which initializer to instantiate.
      kwargs: Keyword arguments when creating the initializer.

    Returns:
      Low-rank initializer.
    """
    rank = solver.rank
    sinkhorn_kwargs = {
        "norm_error": solver._norm_error,
        "lse_mode": solver.lse_mode,
        "implicit_diff": solver.implicit_diff,
        "use_danskin": solver.use_danskin
    }

    if kind == "random":
      return RandomInitializer(rank, **kwargs)
    if kind == "rank2":
      return Rank2Initializer(rank, **kwargs)
    if kind == "k-means":
      return KMeansInitializer(rank, sinkhorn_kwargs=sinkhorn_kwargs, **kwargs)
    if kind == "generalized-k-means":
      return GeneralizedKMeansInitializer(
          rank, sinkhorn_kwargs=sinkhorn_kwargs, **kwargs
      )
    raise NotImplementedError(f"Initializer `{kind}` is not implemented.")

  def __call__(
      self,
      ot_prob: Problem_t,
      q: Optional[jnp.ndarray] = None,
      r: Optional[jnp.ndarray] = None,
      g: Optional[jnp.ndarray] = None,
      *,
      rng: Optional[jax.Array] = None,
      **kwargs: Any
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Initialize the factors :math:`Q`, :math:`R` and :math:`g`.

    Args:
      ot_prob: OT problem.
      q: Factor of shape ``[n, rank]``. If `None`, it will be initialized
        using :meth:`init_q`.
      r: Factor of shape ``[m, rank]``. If `None`, it will be initialized
        using :meth:`init_r`.
      g: Factor of shape ``[rank,]``. If `None`, it will be initialized
        using :meth:`init_g`.
      rng: Random key for seeding.
      kwargs: Additional keyword arguments for :meth:`init_q`, :meth:`init_r`
        and :meth:`init_g`.

    Returns:
      The factors :math:`Q`, :math:`R` and :math:`g`, respectively.
    """
    rng = utils.default_prng_key(rng)
    rng1, rng2, rng3 = jax.random.split(rng, 3)

    if g is None:
      g = self.init_g(ot_prob, rng1, **kwargs)
    if q is None:
      q = self.init_q(ot_prob, rng2, init_g=g, **kwargs)
    if r is None:
      r = self.init_r(ot_prob, rng3, init_g=g, **kwargs)

    assert g.shape == (self.rank,)
    assert q.shape == (ot_prob.a.shape[0], self.rank)
    assert r.shape == (ot_prob.b.shape[0], self.rank)

    return q, r, g

  @property
  def rank(self) -> int:
    """Rank of the transport matrix factorization."""
    return self._rank

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return [], {**self._kwargs, "rank": self.rank}

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "LRInitializer":
    return cls(*children, **aux_data)
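

# Usage sketch (illustrative, not part of the module): create an initializer
# whose Sinkhorn settings match a low-rank solver, then produce the three
# factors for a linear problem. All data below is placeholder.
#
# import jax
# from ott.geometry import pointcloud
# from ott.problems.linear import linear_problem
# from ott.solvers.linear import sinkhorn_lr
#
# rng_x, rng_y, rng_init = jax.random.split(jax.random.PRNGKey(0), 3)
# x = jax.random.normal(rng_x, (128, 3))
# y = jax.random.normal(rng_y, (64, 3))
# prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
# solver = sinkhorn_lr.LRSinkhorn(rank=6)
# initializer = LRInitializer.from_solver(solver, kind="rank2")
# q, r, g = initializer(prob, rng=rng_init)
# # q: [128, 6], r: [64, 6], g: [6,]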


@jax.tree_util.register_pytree_node_class
class RandomInitializer(LRInitializer):
  """Low-rank Sinkhorn factorization using random factors.

  Args:
    rank: Rank of the factorization.
    kwargs: Additional keyword arguments.
  """

  def init_q(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    del kwargs, init_g
    a = ot_prob.a
    init_q = jnp.abs(jax.random.normal(rng, (a.shape[0], self.rank)))
    return a[:, None] * (init_q / jnp.sum(init_q, axis=1, keepdims=True))

  def init_r(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    del kwargs, init_g
    b = ot_prob.b
    init_r = jnp.abs(jax.random.normal(rng, (b.shape[0], self.rank)))
    return b[:, None] * (init_r / jnp.sum(init_r, axis=1, keepdims=True))

  def init_g(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      **kwargs: Any,
  ) -> jnp.ndarray:
    del kwargs
    init_g = jnp.abs(jax.random.uniform(rng, (self.rank,))) + 1.0
    return init_g / jnp.sum(init_g)
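

# Sanity-check sketch (illustrative, not part of the module): the random
# factors are rescaled so that the rows of `q` sum to `a`, the rows of `r`
# sum to `b`, and `g` sums to 1. Placeholder data below.
#
# import jax
# import jax.numpy as jnp
# from ott.geometry import pointcloud
# from ott.problems.linear import linear_problem
#
# x = jax.random.normal(jax.random.PRNGKey(0), (50, 2))
# y = jax.random.normal(jax.random.PRNGKey(1), (40, 2))
# prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
# q, r, g = RandomInitializer(rank=5)(prob, rng=jax.random.PRNGKey(2))
# assert jnp.allclose(q.sum(1), prob.a) and jnp.allclose(r.sum(1), prob.b)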


@jax.tree_util.register_pytree_node_class
class Rank2Initializer(LRInitializer):
  """Low-rank Sinkhorn factorization using rank-2 factors :cite:`scetbon:21`.

  Args:
    rank: Rank of the factorization.
    kwargs: Additional keyword arguments.
  """

  def _compute_factor(
      self,
      ot_prob: Problem_t,
      init_g: jnp.ndarray,
      *,
      which: Literal["q", "r"],
  ) -> jnp.ndarray:
    a, b = ot_prob.a, ot_prob.b
    marginal = a if which == "q" else b
    n, r = marginal.shape[0], self.rank

    lambda_1 = jnp.min(
        jnp.array([jnp.min(a), jnp.min(init_g), jnp.min(b)])
    ) * 0.5

    g1 = jnp.arange(1, r + 1)
    g1 /= g1.astype(float).sum()
    g2 = (init_g - lambda_1 * g1) / (1.0 - lambda_1)

    x = jnp.arange(1, n + 1)
    x /= x.astype(float).sum()
    y = (marginal - lambda_1 * x) / (1.0 - lambda_1)

    return ((lambda_1 * x[:, None] @ g1.reshape(1, -1)) +
            ((1.0 - lambda_1) * y[:, None] @ g2.reshape(1, -1)))

  def init_q(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    del rng, kwargs
    return self._compute_factor(ot_prob, init_g, which="q")

  def init_r(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    del rng, kwargs
    return self._compute_factor(ot_prob, init_g, which="r")

  def init_g(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      **kwargs: Any,
  ) -> jnp.ndarray:
    del rng, kwargs
    return jnp.ones((self.rank,)) / self.rank
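

# Feasibility note (sketch): `_compute_factor` returns the convex combination
# lambda_1 * x @ g1.T + (1 - lambda_1) * y @ g2.T. Since sum(g1) = 1 and,
# because `init_g` sums to 1, sum(g2) = (1 - lambda_1) / (1 - lambda_1) = 1,
# the rows of each factor sum to its marginal and its columns sum to
# `init_g`, so the initialization is feasible and fully deterministic.
# A quick numerical check with placeholder data:
#
# import jax
# import jax.numpy as jnp
# from ott.geometry import pointcloud
# from ott.problems.linear import linear_problem
#
# x = jax.random.normal(jax.random.PRNGKey(0), (30, 2))
# y = jax.random.normal(jax.random.PRNGKey(1), (20, 2))
# prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
# q, r, g = Rank2Initializer(rank=3)(prob)
# assert jnp.allclose(q.sum(1), prob.a) and jnp.allclose(q.sum(0), g)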


@jax.tree_util.register_pytree_node_class
class KMeansInitializer(LRInitializer):
  """K-means initializer for low-rank Sinkhorn :cite:`scetbon:22b`.

  Applicable for :class:`~ott.geometry.pointcloud.PointCloud` and
  :class:`~ott.geometry.low_rank.LRCGeometry`.

  Args:
    rank: Rank of the factorization.
    min_iterations: Minimum number of k-means iterations.
    max_iterations: Maximum number of k-means iterations.
    sinkhorn_kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
    kwargs: Keyword arguments for :func:`~ott.tools.k_means.k_means`.
  """

  def __init__(
      self,
      rank: int,
      min_iterations: int = 100,
      max_iterations: int = 100,
      sinkhorn_kwargs: Optional[Mapping[str, Any]] = None,
      **kwargs: Any
  ):
    super().__init__(rank, **kwargs)
    self._min_iter = min_iterations
    self._max_iter = max_iterations
    self._sinkhorn_kwargs = {} if sinkhorn_kwargs is None else sinkhorn_kwargs

  @staticmethod
  def _extract_array(geom: geometry.Geometry, *, first: bool) -> jnp.ndarray:
    if isinstance(geom, pointcloud.PointCloud):
      return geom.x if first else geom.y
    if isinstance(geom, low_rank.LRCGeometry):
      return geom.cost_1 if first else geom.cost_2
    raise TypeError(
        f"k-means initializer not implemented for `{type(geom).__name__}`."
    )

  def _compute_factor(
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      which: Literal["q", "r"],
      **kwargs: Any,
  ) -> jnp.ndarray:
    from ott.problems.linear import linear_problem
    from ott.problems.quadratic import quadratic_problem
    from ott.solvers.linear import sinkhorn
    from ott.tools import k_means

    del kwargs
    fn = functools.partial(
        k_means.k_means,
        min_iterations=self._min_iter,
        max_iterations=self._max_iter,
        **self._kwargs
    )

    if isinstance(ot_prob, quadratic_problem.QuadraticProblem):
      if ot_prob.geom_xy is not None and ot_prob.fused_penalty >= 1.0:
        # prefer the linear term if it has a higher weight
        geom = ot_prob.geom_xy
      else:
        geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy
    else:
      geom = ot_prob.geom

    arr = self._extract_array(geom, first=which == "q")
    marginals = ot_prob.a if which == "q" else ot_prob.b

    centroids = fn(arr, self.rank, rng=rng).centroids
    geom = pointcloud.PointCloud(
        arr, centroids, epsilon=1e-1, scale_cost="max_cost"
    )

    prob = linear_problem.LinearProblem(geom, marginals, init_g)
    solver = sinkhorn.Sinkhorn(**self._sinkhorn_kwargs)
    return solver(prob).matrix

  def init_q(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    return self._compute_factor(
        ot_prob, rng, init_g=init_g, which="q", **kwargs
    )

  def init_r(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      **kwargs: Any,
  ) -> jnp.ndarray:
    return self._compute_factor(
        ot_prob, rng, init_g=init_g, which="r", **kwargs
    )

  def init_g(  # noqa: D102
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      **kwargs: Any,
  ) -> jnp.ndarray:
    del rng, kwargs
    return jnp.ones((self.rank,)) / self.rank

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    children, aux_data = super().tree_flatten()
    aux_data["sinkhorn_kwargs"] = self._sinkhorn_kwargs
    aux_data["min_iterations"] = self._min_iter
    aux_data["max_iterations"] = self._max_iter
    return children, aux_data
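

# Usage sketch (illustrative, not part of the module): k-means initialization
# needs a geometry exposing either point coordinates (`PointCloud`) or
# low-rank cost factors (`LRCGeometry`); all values below are placeholders.
#
# import jax
# from ott.geometry import pointcloud
# from ott.problems.linear import linear_problem
#
# x = jax.random.normal(jax.random.PRNGKey(0), (200, 4))
# y = jax.random.normal(jax.random.PRNGKey(1), (150, 4))
# prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
# initializer = KMeansInitializer(rank=10, sinkhorn_kwargs={"lse_mode": True})
# q, r, g = initializer(prob, rng=jax.random.PRNGKey(2))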


class GeneralizedKMeansInitializer(KMeansInitializer):
  """Generalized k-means initializer :cite:`scetbon:22b`.

  Applicable for any :class:`~ott.geometry.geometry.Geometry` with a
  square shape.

  Args:
    rank: Rank of the factorization.
    gamma: The (inverse of the) gradient step size used by mirror descent.
    min_iterations: Minimum number of iterations.
    max_iterations: Maximum number of iterations.
    inner_iterations: Number of iterations used by the algorithm before
      re-evaluating progress.
    threshold: Convergence threshold.
    sinkhorn_kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
  """

  def __init__(
      self,
      rank: int,
      gamma: float = 10.0,
      min_iterations: int = 0,
      max_iterations: int = 100,
      inner_iterations: int = 10,
      threshold: float = 1e-6,
      sinkhorn_kwargs: Optional[Mapping[str, Any]] = None,
  ):
    super().__init__(
        rank,
        sinkhorn_kwargs=sinkhorn_kwargs,
        # the arguments below are stored in `_kwargs`
        gamma=gamma,
        min_iterations=min_iterations,
        max_iterations=max_iterations,
        inner_iterations=inner_iterations,
        threshold=threshold,
    )

  class Constants(NamedTuple):  # noqa: D106
    solver: "sinkhorn.Sinkhorn"
    geom: geometry.Geometry  # (n, n)
    marginal: jnp.ndarray  # (n,)
    g: jnp.ndarray  # (r,)
    gamma: float
    threshold: float

  class State(NamedTuple):  # noqa: D106
    factor: jnp.ndarray
    criterions: jnp.ndarray
    crossed_threshold: bool

  def _compute_factor(
      self,
      ot_prob: Problem_t,
      rng: jax.Array,
      *,
      init_g: jnp.ndarray,
      which: Literal["q", "r"],
      **kwargs: Any,
  ) -> jnp.ndarray:
    from ott.problems.linear import linear_problem
    from ott.problems.quadratic import quadratic_problem
    from ott.solvers.linear import sinkhorn

    def init_fn() -> GeneralizedKMeansInitializer.State:
      n = geom.shape[0]
      factor = jnp.abs(jax.random.normal(rng, (n, self.rank))) + 1.0  # (n, r)
      factor *= consts.marginal[:, None] / jnp.sum(
          factor, axis=1, keepdims=True
      )
      return self.State(
          factor,
          criterions=-jnp.ones(outer_iterations),
          crossed_threshold=False
      )

    # see the explanation in `ott.solvers.linear.sinkhorn_lr`
    def converged(
        state: GeneralizedKMeansInitializer.State,
        consts: GeneralizedKMeansInitializer.Constants, iteration: int
    ) -> bool:

      def conv_crossed(prev_err: float, curr_err: float) -> bool:
        return jnp.logical_and(
            prev_err < consts.threshold, curr_err < consts.threshold
        )

      def conv_not_crossed(prev_err: float, curr_err: float) -> bool:
        return jnp.logical_and(curr_err < prev_err, curr_err < consts.threshold)

      it = iteration // inner_iterations
      return jax.lax.cond(
          state.crossed_threshold, conv_crossed, conv_not_crossed,
          state.criterions[it - 2], state.criterions[it - 1]
      )

    def diverged(
        state: GeneralizedKMeansInitializer.State, iteration: int
    ) -> bool:
      it = iteration // inner_iterations
      return jnp.logical_not(jnp.isfinite(state.criterions[it - 1]))

    def cond_fn(
        iteration: int,
        consts: GeneralizedKMeansInitializer.Constants,
        state: GeneralizedKMeansInitializer.State,
    ) -> bool:
      return jnp.logical_or(
          iteration <= 2,
          jnp.logical_and(
              jnp.logical_not(diverged(state, iteration)),
              jnp.logical_not(converged(state, consts, iteration))
          )
      )

    def body_fn(
        iteration: int,
        consts: GeneralizedKMeansInitializer.Constants,
        state: GeneralizedKMeansInitializer.State,
        compute_error: bool
    ) -> GeneralizedKMeansInitializer.State:
      del compute_error
      it = iteration // inner_iterations

      grad = consts.geom.apply_cost(state.factor, axis=1)  # (n, r)
      grad = grad + consts.geom.apply_cost(state.factor, axis=0)  # (n, r)
      grad = grad / consts.g

      norm = jnp.max(jnp.abs(grad)) ** 2
      gamma = consts.gamma / norm
      eps = 1.0 / gamma
      cost = grad - eps * mu.safe_log(state.factor)  # (n, r)

      cost = geometry.Geometry(cost_matrix=cost, epsilon=eps)
      problem = linear_problem.LinearProblem(
          cost, a=consts.marginal, b=consts.g
      )
      out = consts.solver(problem)

      new_factor = out.matrix
      criterion = ((1 / gamma) ** 2) * (
          mu.kl(new_factor, state.factor) + mu.kl(state.factor, new_factor)
      )
      crossed_threshold = jnp.logical_or(
          state.crossed_threshold,
          jnp.logical_and(
              state.criterions[it - 1] >= consts.threshold,
              criterion < consts.threshold
          )
      )

      return self.State(
          factor=new_factor,
          criterions=state.criterions.at[it].set(criterion),
          crossed_threshold=crossed_threshold
      )

    del kwargs
    if isinstance(ot_prob, quadratic_problem.QuadraticProblem):
      geom = ot_prob.geom_xx if which == "q" else ot_prob.geom_yy
    else:
      geom = ot_prob.geom
    assert geom.shape[0] == geom.shape[1], \
        f"Expected the shape to be square, found `{geom.shape}`."

    inner_iterations = self._kwargs["inner_iterations"]
    outer_iterations = np.ceil(self._max_iter / inner_iterations).astype(int)
    force_scan = self._min_iter == self._max_iter
    fixpoint_fn = (
        fixed_point_loop.fixpoint_iter
        if force_scan else fixed_point_loop.fixpoint_iter_backprop
    )

    consts = self.Constants(
        solver=sinkhorn.Sinkhorn(**self._sinkhorn_kwargs),
        geom=geom.set_scale_cost("max_cost"),
        marginal=ot_prob.a if which == "q" else ot_prob.b,
        g=init_g,
        gamma=self._kwargs["gamma"],
        threshold=self._kwargs["threshold"],
    )

    return fixpoint_fn(
        cond_fn,
        body_fn,
        min_iterations=self._min_iter,
        max_iterations=self._max_iter,
        inner_iterations=inner_iterations,
        constants=consts,
        state=init_fn(),
    ).factor
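

# Usage sketch (illustrative, not part of the module): the generalized variant
# only requires a square cost matrix, so it also covers geometries without
# point coordinates. The cost matrix below is random placeholder data.
#
# import jax
# import jax.numpy as jnp
# from ott.geometry import geometry
# from ott.problems.linear import linear_problem
#
# cost = jnp.abs(jax.random.normal(jax.random.PRNGKey(0), (80, 80)))
# prob = linear_problem.LinearProblem(geometry.Geometry(cost_matrix=cost))
# initializer = GeneralizedKMeansInitializer(rank=6, gamma=10.0)
# q, r, g = initializer(prob, rng=jax.random.PRNGKey(1))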