# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import abc
from typing import Any, Dict, Optional, Sequence, Tuple

import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem

__all__ = [
"DefaultInitializer", "GaussianInitializer", "SortingInitializer",
"SubsampleInitializer"
]


@jax.tree_util.register_pytree_node_class
class SinkhornInitializer(abc.ABC):
  """Base class for Sinkhorn initializers."""

  @abc.abstractmethod
def init_fu(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
"""Initialize Sinkhorn potential/scaling f_u.
Args:
ot_prob: Linear OT problem.
lse_mode: Return potential if ``True``, scaling if ``False``.
rng: Random number generator for stochastic initializers.
Returns:
potential/scaling, array of size ``[n,]``.
"""
@abc.abstractmethod
def init_gv(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
"""Initialize Sinkhorn potential/scaling g_v.
Args:
ot_prob: Linear OT problem.
lse_mode: Return potential if ``True``, scaling if ``False``.
rng: Random number generator for stochastic initializers.
Returns:
potential/scaling, array of size ``[m,]``.
"""
def __call__(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Initialize Sinkhorn potentials/scalings f_u and g_v.
Args:
ot_prob: Linear OT problem.
a: Initial potential/scaling f_u.
If ``None``, it will be initialized using :meth:`init_dual_a`.
b: Initial potential/scaling g_v.
If ``None``, it will be initialized using :meth:`init_dual_b`.
lse_mode: Return potentials if ``True``, scalings if ``False``.
rng: Random number generator for stochastic initializers.
Returns:
The initial potentials/scalings.
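
    Example:
      A minimal usage sketch with the concrete :class:`DefaultInitializer`;
      the point clouds ``x`` and ``y`` are assumptions for illustration::

        geom = pointcloud.PointCloud(x, y)
        prob = linear_problem.LinearProblem(geom)
        f, g = DefaultInitializer()(prob, lse_mode=True)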
"""
rng = utils.default_prng_key(rng)
rng_f, rng_g = jax.random.split(rng, 2)
fu = self.init_fu(ot_prob, lse_mode=lse_mode, rng=rng_f)
gv = self.init_gv(ot_prob, lse_mode=lse_mode, rng=rng_g)
# cancel dual variables for zero weights
mask_value = -jnp.inf if lse_mode else 0.0
fu = jnp.where(ot_prob.a > 0.0, fu, mask_value)
gv = jnp.where(ot_prob.b > 0.0, gv, mask_value)
    return fu, gv

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
    return [], {}

  @classmethod
  def tree_unflatten(  # noqa: D102
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "SinkhornInitializer":
    return cls(*children, **aux_data)


@jax.tree_util.register_pytree_node_class
class DefaultInitializer(SinkhornInitializer):
"""Default initialization of Sinkhorn dual potentials/primal scalings."""

  def init_fu(  # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
del rng
return jnp.zeros_like(ot_prob.a) if lse_mode else jnp.ones_like(ot_prob.a)

  def init_gv(  # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
del rng
return jnp.zeros_like(ot_prob.b) if lse_mode else jnp.ones_like(ot_prob.b)


@jax.tree_util.register_pytree_node_class
class GaussianInitializer(DefaultInitializer):
"""Gaussian initializer :cite:`thornton2022rethinking:22`.
Compute Gaussian approximations of each
:class:`~ott.geometry.pointcloud.PointCloud`, then compute closed from
Kantorovich potential between Gaussian approximations using Brenier's theorem
(adapt convex/Brenier potential to Kantorovich). Use this Gaussian potential
to initialize Sinkhorn potentials/scalings.
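
  Example:
    A minimal sketch; the arrays and shapes are assumptions for
    illustration::

      import jax
      rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
      x = jax.random.normal(rng_x, (128, 2))
      y = jax.random.normal(rng_y, (64, 2)) + 1.0
      geom = pointcloud.PointCloud(x, y)
      prob = linear_problem.LinearProblem(geom)
      f = GaussianInitializer().init_fu(prob, lse_mode=True)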
"""

  def init_fu(  # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
# import Gaussian here due to circular imports
from ott.tools.gaussian_mixture import gaussian
del rng
assert isinstance(
ot_prob.geom, pointcloud.PointCloud
), "Gaussian initializer valid only for pointcloud geoms."
x, y = ot_prob.geom.x, ot_prob.geom.y
a, b = ot_prob.a, ot_prob.b
gaussian_a = gaussian.Gaussian.from_samples(x, weights=a)
gaussian_b = gaussian.Gaussian.from_samples(y, weights=b)
# Brenier potential for cost ||x-y||^2/2, multiply by two for ||x-y||^2
f_potential = 2 * gaussian_a.f_potential(dest=gaussian_b, points=x)
f_potential = f_potential - jnp.mean(f_potential)
return f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
f_potential
)


@jax.tree_util.register_pytree_node_class
class SortingInitializer(DefaultInitializer):
"""Sorting initializer :cite:`thornton2022rethinking:22`.
Solve non-regularized OT problem via sorting, then compute potential through
iterated minimum on C-transform and use this potential to initialize
regularized potential.
Args:
vectorized_update: Whether to use vectorized loop.
tolerance: DualSort convergence threshold.
max_iter: Max DualSort steps.
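
  Example:
    A minimal sketch; DualSort needs a square cost matrix, so both point
    clouds have the same number of points (arrays and shapes are assumptions
    for illustration)::

      import jax
      x = jax.random.normal(jax.random.PRNGKey(0), (64, 2))
      geom = pointcloud.PointCloud(x, x + 1.0)
      prob = linear_problem.LinearProblem(geom)
      f = SortingInitializer().init_fu(prob, lse_mode=True)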
"""

  def __init__(
self,
vectorized_update: bool = True,
tolerance: float = 1e-2,
max_iter: int = 100
):
super().__init__()
self.tolerance = tolerance
self.max_iter = max_iter
self.vectorized_update = vectorized_update

  def _init_sorting_dual(
self, modified_cost: jnp.ndarray, init_f: jnp.ndarray
) -> jnp.ndarray:
"""Run DualSort algorithm.
Args:
modified_cost: cost matrix minus diagonal column-wise.
init_f: potential f, array of size n. This is the starting potential,
which is then updated to make the init potential, so an init of an init.
Returns:
potential f, array of size n.
"""
def body_fn(
state: Tuple[jnp.ndarray, float, int]
) -> Tuple[jnp.ndarray, float, int]:
prev_f, _, it = state
new_f = fn(prev_f, modified_cost)
diff = jnp.sum((new_f - prev_f) ** 2)
it += 1
return new_f, diff, it

    def cond_fn(state: Tuple[jnp.ndarray, float, int]) -> bool:
_, diff, it = state
return jnp.logical_and(diff > self.tolerance, it < self.max_iter)

    fn = _vectorized_update if self.vectorized_update else _coordinate_update
state = (init_f, jnp.inf, 0) # init, error, iter
f_potential, _, _ = jax.lax.while_loop(
cond_fun=cond_fn, body_fun=body_fn, init_val=state
)
return f_potential

  def init_fu(
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
init_f: Optional[jnp.ndarray] = None,
) -> jnp.ndarray:
"""Apply DualSort algorithm.
Args:
ot_prob: OT problem between discrete distributions.
lse_mode: Return potential if ``True``, scaling if ``False``.
rng: Random number generator for stochastic initializers, unused.
init_f: potential f, array of size ``[n,]``. This is the starting
potential, which is then updated to make the init potential,
so an init of an init.
Returns:
potential/scaling f_u, array of size ``[n,]``.
"""
del rng
assert not ot_prob.geom.is_online, \
"Sorting initializer does not work for online geometry."
# check for sorted x, y requires point cloud and could slow initializer
cost_matrix = ot_prob.geom.cost_matrix
    assert cost_matrix.shape[0] == cost_matrix.shape[1], \
      "Requires square cost matrix."
modified_cost = cost_matrix - jnp.diag(cost_matrix)[None, :]
n = cost_matrix.shape[0]
init_f = jnp.zeros(n, dtype=ot_prob.dtype) if init_f is None else init_f
f_potential = self._init_sorting_dual(modified_cost, init_f)
f_potential = f_potential - jnp.mean(f_potential)
return f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
f_potential
)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
return ([], {
"tolerance": self.tolerance,
"max_iter": self.max_iter,
"vectorized_update": self.vectorized_update
})


@jax.tree_util.register_pytree_node_class
class SubsampleInitializer(DefaultInitializer):
"""Subsample initializer :cite:`thornton2022rethinking:22`.
Subsample each :class:`~ott.geometry.pointcloud.PointCloud`, then compute
:class:`Sinkhorn potential <ott.problems.linear.potentials.DualPotentials>`
from the subsampled approximations and use this potential to initialize
Sinkhorn potentials/scalings for the original problem.
Args:
subsample_n_x: number of points to subsample from the first measure in
:class:`~ott.geometry.pointcloud.PointCloud`.
subsample_n_y: number of points to subsample from the second measure in
:class:`~ott.geometry.pointcloud.PointCloud`.
If :obj:`None`, use ``subsample_n_x``.
kwargs: Keyword arguments for :func:`~ott.solvers.linear.solve`.
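
  Example:
    A minimal sketch; ``x``, ``y`` and the subsample size are assumptions
    for illustration::

      geom = pointcloud.PointCloud(x, y)
      prob = linear_problem.LinearProblem(geom)
      init = SubsampleInitializer(subsample_n_x=32)
      f, g = init(prob, lse_mode=True)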
"""

  def __init__(
self,
subsample_n_x: int,
subsample_n_y: Optional[int] = None,
**kwargs: Any,
):
super().__init__()
self.subsample_n_x = subsample_n_x
self.subsample_n_y = subsample_n_y or subsample_n_x
self.sinkhorn_kwargs = kwargs

  def init_fu(  # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
rng: Optional[jax.Array] = None,
) -> jnp.ndarray:
from ott.solvers import linear
assert isinstance(
ot_prob.geom, pointcloud.PointCloud
), "Subsample initializer valid only for pointcloud geom."
rng = utils.default_prng_key(rng)
rng_x, rng_y = jax.random.split(rng, 2)
x, y = ot_prob.geom.x, ot_prob.geom.y
a, b = ot_prob.a, ot_prob.b
# subsample
sub_x = jax.random.choice(
rng_x, a=x, shape=(self.subsample_n_x,), replace=True, p=a, axis=0
)
sub_y = jax.random.choice(
rng_y, a=y, shape=(self.subsample_n_y,), replace=True, p=b, axis=0
)
# create subsampled point cloud geometry
sub_geom = pointcloud.PointCloud(
sub_x,
sub_y,
epsilon=ot_prob.geom.epsilon,
scale_cost=ot_prob.geom._scale_cost,
cost_fn=ot_prob.geom.cost_fn
)
# run sinkhorn
subsample_sink_out = linear.solve(sub_geom, **self.sinkhorn_kwargs)
# interpolate potentials
dual_potentials = subsample_sink_out.to_dual_potentials()
f_potential = jax.vmap(dual_potentials.f)(x)
return f_potential if lse_mode else ot_prob.geom.scaling_from_potential(
f_potential
)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
return ([], {
"subsample_n_x": self.subsample_n_x,
"subsample_n_y": self.subsample_n_y,
**self.sinkhorn_kwargs
})


def _vectorized_update(
f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
"""Inner loop DualSort Update.
Args:
f: potential f, array of size n.
modified_cost: cost matrix minus diagonal column-wise.
Returns:
updated potential vector, f.
"""
return jnp.min(modified_cost + f[None, :], axis=1)


def _coordinate_update(
f: jnp.ndarray, modified_cost: jnp.ndarray
) -> jnp.ndarray:
"""Coordinate-wise updates within inner loop.
Args:
f: potential f, array of size n.
modified_cost: cost matrix minus diagonal column-wise.
Returns:
updated potential vector, f.
"""
def body_fn(i: int, f: jnp.ndarray) -> jnp.ndarray:
new_f = jnp.min(modified_cost[i, :] + f)
return f.at[i].set(new_f)
return jax.lax.fori_loop(0, len(f), body_fn, f)