# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Mapping, NamedTuple, Optional, Tuple
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from ott.geometry import geometry
from ott.initializers.linear import initializers_lr
from ott.math import fixed_point_loop
from ott.math import utils as mu
from ott.problems.linear import linear_problem
from ott.solvers.linear import lr_utils, sinkhorn
__all__ = ["LRSinkhorn", "LRSinkhornOutput"]
ProgressFunction = Callable[
[Tuple[np.ndarray, np.ndarray, np.ndarray, "LRSinkhornState"]], None]
class LRSinkhornState(NamedTuple):
"""State of the Low Rank Sinkhorn algorithm."""
q: jnp.ndarray
r: jnp.ndarray
g: jnp.ndarray
gamma: float
costs: jnp.ndarray
errors: jnp.ndarray
crossed_threshold: bool
def compute_error( # noqa: D102
self, previous_state: "LRSinkhornState"
) -> float:
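# convergence criterion: generalized Jensen-Shannon divergences between
# successive (q, r, g) factors.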
err_q = mu.gen_js(self.q, previous_state.q, c=1.0)
err_r = mu.gen_js(self.r, previous_state.r, c=1.0)
err_g = mu.gen_js(self.g, previous_state.g, c=1.0)
# don't scale by (1 / gamma ** 2); https://github.com/ott-jax/ott/pull/547
return err_q + err_r + err_g
def reg_ot_cost( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
*,
epsilon: float,
use_danskin: bool = True
) -> float:
"""For LR Sinkhorn, this defaults to the primal cost of LR solution."""
return compute_reg_ot_cost(
self.q,
self.r,
self.g,
ot_prob,
epsilon=epsilon,
use_danskin=use_danskin
)
def solution_error( # noqa: D102
self, ot_prob: linear_problem.LinearProblem, norm_error: Tuple[int, ...]
) -> jnp.ndarray:
return solution_error(self.q, self.r, ot_prob, norm_error)
def set(self, **kwargs: Any) -> "LRSinkhornState":
"""Return a copy of self, with potential overwrites."""
return self._replace(**kwargs)
def compute_reg_ot_cost(
q: jnp.ndarray,
r: jnp.ndarray,
g: jnp.ndarray,
ot_prob: linear_problem.LinearProblem,
epsilon: float,
use_danskin: bool = True
) -> float:
"""Compute the regularized OT cost, here the primal cost of the LR solution.
Args:
q: First factor of the solution.
r: Second factor of the solution.
g: Weights of the solution.
ot_prob: Linear OT problem.
epsilon: Entropic regularization.
use_danskin: If ``True``, use Danskin's theorem :cite:`danskin:67,bertsekas:71`
to avoid having to differentiate the three factors ``q``, ``r`` and ``g``
w.r.t. relevant quantities.
Returns:
Regularized OT cost, the (primal) transport cost of the low-rank solution.
"""
tau_a, tau_b = ot_prob.tau_a, ot_prob.tau_b
q = jax.lax.stop_gradient(q) if use_danskin else q
r = jax.lax.stop_gradient(r) if use_danskin else r
g = jax.lax.stop_gradient(g) if use_danskin else g
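# transport cost <C, Q diag(1/g) R^T>, computed factor-wise as
# sum_{i,k} Q_{ik} (C R)_{ik} / g_k, without instantiating the full coupling.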
cost = jnp.sum(ot_prob.geom.apply_cost(r, axis=1) * q * (1.0 / g)[None, :])
cost -= epsilon * (mu.gen_ent(q) + mu.gen_ent(r) + mu.gen_ent(g))
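# unbalanced case: penalize deviations of the factor marginals from a and b
# with generalized KL terms, weighted by tau / (1 - tau).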
if tau_a != 1.0:
cost += tau_a / (1.0 - tau_a) * mu.gen_kl(jnp.sum(q, axis=1), ot_prob.a)
if tau_b != 1.0:
cost += tau_b / (1.0 - tau_b) * mu.gen_kl(jnp.sum(r, axis=1), ot_prob.b)
return cost
def solution_error(
q: jnp.ndarray, r: jnp.ndarray, ot_prob: linear_problem.LinearProblem,
norm_error: Tuple[int, ...]
) -> jnp.ndarray:
"""Compute solution error.
Since only balanced case is available for LR, this is marginal deviation.
Args:
q: first factor of solution.
r: second factor of solution.
ot_prob: linear problem.
norm_error: int, p-norm used to compute error.
Returns:
one or possibly many numbers quantifying deviation to true marginals.
"""
norm_error = jnp.array(norm_error)
# Update the error
err = jnp.sum(
jnp.abs(jnp.sum(q, axis=1) - ot_prob.a) ** norm_error[:, None], axis=1
) ** (1.0 / norm_error)
err += jnp.sum(
jnp.abs(jnp.sum(r, axis=1) - ot_prob.b) ** norm_error[:, None], axis=1
) ** (1.0 / norm_error)
err += jnp.sum(
jnp.abs(jnp.sum(q, axis=0) - jnp.sum(r, axis=0)) ** norm_error[:, None],
axis=1
) ** (1.0 / norm_error)
return err
class LRSinkhornOutput(NamedTuple):
"""Transport interface for a low-rank Sinkhorn solution."""
q: jnp.ndarray
r: jnp.ndarray
g: jnp.ndarray
costs: jnp.ndarray
# TODO(michalk8): must be called `errors`, because of `store_inner_errors`
# in future, enforce via class hierarchy
errors: jnp.ndarray
ot_prob: linear_problem.LinearProblem
epsilon: float
inner_iterations: int
converged: bool
# TODO(michalk8): Optional is an artifact of the current impl., refactor
reg_ot_cost: Optional[float] = None
def set(self, **kwargs: Any) -> "LRSinkhornOutput":
"""Return a copy of self, with potential overwrites."""
return self._replace(**kwargs)
def set_cost( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
lse_mode: bool,
use_danskin: bool = True
) -> "LRSinkhornOutput":
del lse_mode
return self.set(reg_ot_cost=self.compute_reg_ot_cost(ot_prob, use_danskin))
def compute_reg_ot_cost( # noqa: D102
self,
ot_prob: linear_problem.LinearProblem,
use_danskin: bool = True,
) -> float:
return compute_reg_ot_cost(
self.q,
self.r,
self.g,
ot_prob,
epsilon=self.epsilon,
use_danskin=use_danskin
)
@property
def geom(self) -> geometry.Geometry: # noqa: D102
return self.ot_prob.geom
@property
def a(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.a
@property
def b(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.b
@property
def n_iters(self) -> int: # noqa: D102
return jnp.sum(self.errors != -1) * self.inner_iterations
@property
def matrix(self) -> jnp.ndarray:
"""Transport matrix if it can be instantiated."""
return (self.q * self._inv_g) @ self.r.T
def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
"""Apply the transport to a array; axis=1 for its transpose."""
q, r = (self.q, self.r) if axis == 1 else (self.r, self.q)
# for `axis=0`: (batch, n) x (n, rank) x (rank,) x (rank, m) -> (batch, m)
return ((inputs @ r) * self._inv_g) @ q.T
def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102
length = self.q.shape[0] if axis == 0 else self.r.shape[0]
return self.apply(jnp.ones(length,), axis=axis)
def cost_at_geom(self, other_geom: geometry.Geometry) -> float:
"""Return OT cost for current solution, evaluated at any cost matrix."""
return jnp.sum(self.q * other_geom.apply_cost(self.r, axis=1) * self._inv_g)
def transport_cost_at_geom(self, other_geom: geometry.Geometry) -> float:
"""Return (by recomputing it) bare transport cost of current solution."""
return self.cost_at_geom(other_geom)
@property
def primal_cost(self) -> float:
"""Return (by recomputing it) transport cost of current solution."""
return self.transport_cost_at_geom(other_geom=self.geom)
@property
def transport_mass(self) -> float:
"""Sum of transport matrix."""
return self.marginal(0).sum()
@property
def _inv_g(self) -> jnp.ndarray:
return 1.0 / self.g
@jax.tree_util.register_pytree_node_class
class LRSinkhorn(sinkhorn.Sinkhorn):
r"""Low-Rank Sinkhorn solver for linear reg-OT problems.
The algorithm minimizes the :term:`low-rank optimal transport` problem, a
formulation of the :term:`Kantorovich problem` in which the :term:`coupling`
variable is constrained to have low rank.
That problem is non-convex, and therefore any algorithm that tries to
solve it requires special attention to initialization and control of
convergence. Convergence is assessed on successive evaluations of the
objective, whereas initializers are instances of the
:class:`~ott.initializers.linear.initializers_lr.LRInitializer` class.
The algorithm is described in :cite:`scetbon:21` and the implementation
contained here is adapted from `LOT <https://github.com/meyerscetbon/LOT>`_.
Args:
rank: Rank constraint on the coupling used to minimize the linear OT problem.
gamma: The (inverse of) gradient step size used by mirror descent.
gamma_rescale: Whether to rescale :math:`\gamma` every iteration as
described in :cite:`scetbon:22b`.
epsilon: Entropic regularization added on top of low-rank problem.
initializer: How to initialize the :math:`Q`, :math:`R` and :math:`g`
factors.
lse_mode: Whether to run computations in LSE or kernel mode.
inner_iterations: Number of inner iterations used by the algorithm before
re-evaluating progress.
use_danskin: Use Danskin's theorem to evaluate the gradient of the objective
w.r.t. input parameters. Only ``True`` is handled at the moment.
progress_fn: Callback function called during the Sinkhorn
iterations, so the user can display the error at each iteration,
e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn`
for a basic implementation.
kwargs_dys: Keyword arguments passed to :meth:`dykstra_update_lse`,
:meth:`dykstra_update_kernel` or one of the functions defined in
:mod:`ott.solvers.linear`, depending on whether the problem
is balanced and on the ``lse_mode``.
kwargs: Keyword arguments for
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
"""
def __init__(
self,
rank: int,
gamma: float = 10.0,
gamma_rescale: bool = True,
epsilon: float = 0.0,
initializer: Optional[initializers_lr.LRInitializer] = None,
lse_mode: bool = True,
inner_iterations: int = 10,
use_danskin: bool = True,
kwargs_dys: Optional[Mapping[str, Any]] = None,
progress_fn: Optional[ProgressFunction] = None,
**kwargs: Any,
):
kwargs["implicit_diff"] = None # not yet implemented
super().__init__(
lse_mode=lse_mode,
inner_iterations=inner_iterations,
use_danskin=use_danskin,
**kwargs
)
self.rank = rank
self.gamma = gamma
self.gamma_rescale = gamma_rescale
self.epsilon = epsilon
self.initializer = initializers_lr.RandomInitializer(
rank
) if initializer is None else initializer
self.progress_fn = progress_fn
self.kwargs_dys = {} if kwargs_dys is None else kwargs_dys
def __call__(
self,
ot_prob: linear_problem.LinearProblem,
init: Optional[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]] = None,
**kwargs: Any,
) -> LRSinkhornOutput:
"""Run low-rank Sinkhorn.
Args:
ot_prob: Linear OT problem.
init: Initial values for the low-rank factors:
- :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.q`.
- :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.r`.
- :attr:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornOutput.g`.
If :obj:`None`, run the initializer.
kwargs: Keyword arguments for the initializer.
Returns:
The low-rank Sinkhorn output.
"""
if init is None:
init = self.initializer(ot_prob, **kwargs)
return run(ot_prob, self, init)
def _get_costs(
self,
ot_prob: linear_problem.LinearProblem,
state: LRSinkhornState,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]:
log_q, log_r, log_g = (
mu.safe_log(state.q), mu.safe_log(state.r), mu.safe_log(state.g)
)
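# gradients of the transport cost <C, Q diag(1/g) R^T> w.r.t. Q, R and g;
# the full coupling is never instantiated.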
inv_g = 1.0 / state.g[None, :]
tmp = ot_prob.geom.apply_cost(state.r, axis=1)
grad_q = tmp * inv_g
grad_r = ot_prob.geom.apply_cost(state.q) * inv_g
grad_g = -jnp.sum(state.q * tmp, axis=0) / (state.g ** 2)
grad_q += self.epsilon * log_q
grad_r += self.epsilon * log_r
grad_g += self.epsilon * log_g
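# adaptive step size (:cite:`scetbon:22b`): divide gamma by the largest
# squared sup-norm of the three gradients.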
if self.gamma_rescale:
norm_q = jnp.max(jnp.abs(grad_q)) ** 2
norm_r = jnp.max(jnp.abs(grad_r)) ** 2
norm_g = jnp.max(jnp.abs(grad_g)) ** 2
gamma = self.gamma / jnp.max(jnp.array([norm_q, norm_r, norm_g]))
else:
gamma = self.gamma
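# mirror-descent step in KL geometry: with entropic regularization epsilon,
# the unconstrained update has log-kernel
# (-gamma * grad + log(previous)) / (1 + epsilon * gamma); marginal
# constraints are enforced afterwards by the Dykstra updates.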
eps_factor = 1.0 / (self.epsilon * gamma + 1.0)
gamma *= eps_factor
c_q = -gamma * grad_q + eps_factor * log_q
c_r = -gamma * grad_r + eps_factor * log_r
c_g = -gamma * grad_g + eps_factor * log_g
return c_q, c_r, c_g, gamma
# TODO(michalk8): move to `lr_utils` when refactoring this
def dykstra_update_lse(
self,
c_q: jnp.ndarray,
c_r: jnp.ndarray,
h: jnp.ndarray,
gamma: float,
ot_prob: linear_problem.LinearProblem,
min_entry_value: float = 1e-6,
tolerance: float = 1e-3,
min_iter: int = 0,
inner_iter: int = 10,
max_iter: int = 10000
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Run Dykstra's algorithm."""
# shortcuts for problem's definition.
r = self.rank
n, m = ot_prob.geom.shape
loga, logb = jnp.log(ot_prob.a), jnp.log(ot_prob.b)
h_old = h
g1_old, g2_old = jnp.zeros(r), jnp.zeros(r)
f1, f2 = jnp.zeros(n), jnp.zeros(m)
w_gi, w_gp = jnp.zeros(r), jnp.zeros(r)
w_q, w_r = jnp.zeros(r), jnp.zeros(r)
err = jnp.inf
state_inner = f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err
constants = c_q, c_r, loga, logb
def cond_fn(
iteration: int, constants: Tuple[jnp.ndarray, ...],
state_inner: Tuple[jnp.ndarray, ...]
) -> bool:
del iteration, constants
*_, err = state_inner
return err > tolerance
def _softm(
f: jnp.ndarray, g: jnp.ndarray, c: jnp.ndarray, axis: int
) -> jnp.ndarray:
return jsp.special.logsumexp(
gamma * (f[:, None] + g[None, :] - c), axis=axis
)
def body_fn(
iteration: int, constants: Tuple[jnp.ndarray, ...],
state_inner: Tuple[jnp.ndarray, ...], compute_error: bool
) -> Tuple[jnp.ndarray, ...]:
# TODO(michalk8): in the future, use `NamedTuple`
f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err = state_inner
c_q, c_r, loga, logb = constants
# First Projection
f1 = jnp.where(
jnp.isfinite(loga),
(loga - _softm(f1, g1_old, c_q, axis=1)) / gamma + f1, loga
)
f2 = jnp.where(
jnp.isfinite(logb),
(logb - _softm(f2, g2_old, c_r, axis=1)) / gamma + f2, logb
)
h = h_old + w_gi
h = jnp.maximum(jnp.log(min_entry_value) / gamma, h)
w_gi += h_old - h
h_old = h
# Update couplings
g_q = _softm(f1, g1_old, c_q, axis=0)
g_r = _softm(f2, g2_old, c_r, axis=0)
# Second Projection
h = (1.0 / 3.0) * (h_old + w_gp + w_q + w_r)
h += g_q / (3.0 * gamma)
h += g_r / (3.0 * gamma)
g1 = h + g1_old - g_q / gamma
g2 = h + g2_old - g_r / gamma
w_q = w_q + g1_old - g1
w_r = w_r + g2_old - g2
w_gp = h_old + w_gp - h
q, r, _ = recompute_couplings(f1, g1, c_q, f2, g2, c_r, h, gamma)
g1_old = g1
g2_old = g2
h_old = h
err = jax.lax.cond(
jnp.logical_and(compute_error, iteration >= min_iter),
lambda: solution_error(q, r, ot_prob, self.norm_error)[0], lambda: err
)
return f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err
def recompute_couplings(
f1: jnp.ndarray,
g1: jnp.ndarray,
c_q: jnp.ndarray,
f2: jnp.ndarray,
g2: jnp.ndarray,
c_r: jnp.ndarray,
h: jnp.ndarray,
gamma: float,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
q = jnp.exp(gamma * (f1[:, None] + g1[None, :] - c_q))
r = jnp.exp(gamma * (f2[:, None] + g2[None, :] - c_r))
g = jnp.exp(gamma * h)
return q, r, g
state_inner = fixed_point_loop.fixpoint_iter_backprop(
cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner
)
f1, f2, g1_old, g2_old, h_old, _, _, _, _, _ = state_inner
return recompute_couplings(f1, g1_old, c_q, f2, g2_old, c_r, h_old, gamma)
def dykstra_update_kernel(
self,
k_q: jnp.ndarray,
k_r: jnp.ndarray,
k_g: jnp.ndarray,
gamma: float,
ot_prob: linear_problem.LinearProblem,
min_entry_value: float = 1e-6,
tolerance: float = 1e-3,
min_iter: int = 0,
inner_iter: int = 10,
max_iter: int = 10000
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Run Dykstra's algorithm."""
# shortcuts for problem's definition.
rank = self.rank
n, m = ot_prob.geom.shape
a, b = ot_prob.a, ot_prob.b
supp_a, supp_b = a > 0, b > 0
g_old = k_g
v1_old, v2_old = jnp.ones(rank), jnp.ones(rank)
u1, u2 = jnp.ones(n), jnp.ones(m)
q_gi, q_gp = jnp.ones(rank), jnp.ones(rank)
q_q, q_r = jnp.ones(rank), jnp.ones(rank)
err = jnp.inf
state_inner = u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err
constants = k_q, k_r, k_g, a, b
def cond_fn(
iteration: int, constants: Tuple[jnp.ndarray, ...],
state_inner: Tuple[jnp.ndarray, ...]
) -> bool:
del iteration, constants
*_, err = state_inner
return err > tolerance
def body_fn(
iteration: int, constants: Tuple[jnp.ndarray, ...],
state_inner: Tuple[jnp.ndarray, ...], compute_error: bool
) -> Tuple[jnp.ndarray, ...]:
# TODO(michalk8): in the future, use `NamedTuple`
u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err = state_inner
k_q, k_r, k_g, a, b = constants
# First Projection
u1 = jnp.where(supp_a, a / jnp.dot(k_q, v1_old), 0.0)
u2 = jnp.where(supp_b, b / jnp.dot(k_r, v2_old), 0.0)
g = jnp.maximum(min_entry_value, g_old * q_gi)
q_gi = (g_old * q_gi) / g
g_old = g
# Second Projection
v1_trans = jnp.dot(k_q.T, u1)
v2_trans = jnp.dot(k_r.T, u2)
g = (g_old * q_gp * v1_old * q_q * v1_trans * v2_old * q_r *
v2_trans) ** (1 / 3)
v1 = g / v1_trans
v2 = g / v2_trans
q_gp = (g_old * q_gp) / g
q_q = (q_q * v1_old) / v1
q_r = (q_r * v2_old) / v2
v1_old = v1
v2_old = v2
g_old = g
# Compute Couplings
q, r, _ = recompute_couplings(u1, v1, k_q, u2, v2, k_r, g)
err = jax.lax.cond(
jnp.logical_and(compute_error, iteration >= min_iter),
lambda: solution_error(q, r, ot_prob, self.norm_error)[0], lambda: err
)
return u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err
def recompute_couplings(
u1: jnp.ndarray,
v1: jnp.ndarray,
k_q: jnp.ndarray,
u2: jnp.ndarray,
v2: jnp.ndarray,
k_r: jnp.ndarray,
g: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
q = u1.reshape((-1, 1)) * k_q * v1.reshape((1, -1))
r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1))
return q, r, g
state_inner = fixed_point_loop.fixpoint_iter_backprop(
cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner
)
u1, u2, v1_old, v2_old, g_old, _, _, _, _, _ = state_inner
return recompute_couplings(u1, v1_old, k_q, u2, v2_old, k_r, g_old)
def lse_step(
self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState,
iteration: int
) -> LRSinkhornState:
"""LR Sinkhorn LSE update."""
c_q, c_r, c_g, gamma = self._get_costs(ot_prob, state)
if ot_prob.is_balanced:
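# balanced case: rescale to the sign/scale convention expected by
# `dykstra_update_lse`, which works with gamma * (f + g - c).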
c_q, c_r, h = c_q / -gamma, c_r / -gamma, c_g / gamma
q, r, g = self.dykstra_update_lse(
c_q, c_r, h, gamma, ot_prob, **self.kwargs_dys
)
else:
q, r, g = lr_utils.unbalanced_dykstra_lse(
c_q, c_r, c_g, gamma, ot_prob, **self.kwargs_dys
)
return state.set(q=q, g=g, r=r, gamma=gamma)
def kernel_step(
self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState,
iteration: int
) -> LRSinkhornState:
"""LR Sinkhorn Kernel update."""
c_q, c_r, c_g, gamma = self._get_costs(ot_prob, state)
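# kernel mode: exponentiate the log-kernels from `_get_costs` before the
# scaling-space Dykstra updates.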
c_q, c_r, c_g = jnp.exp(c_q), jnp.exp(c_r), jnp.exp(c_g)
if ot_prob.is_balanced:
q, r, g = self.dykstra_update_kernel(
c_q, c_r, c_g, gamma, ot_prob, **self.kwargs_dys
)
else:
q, r, g = lr_utils.unbalanced_dykstra_kernel(
c_q, c_r, c_g, gamma, ot_prob, **self.kwargs_dys
)
return state.set(q=q, g=g, r=r, gamma=gamma)
def one_iteration(
self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState,
iteration: int, compute_error: bool
) -> LRSinkhornState:
"""Carries out one low-rank Sinkhorn iteration.
Depending on lse_mode, these iterations can be either in:
- log-space for numerical stability.
- scaling space, using standard kernel-vector multiply operations.
Args:
ot_prob: The transport problem definition.
state: LRSinkhornState named tuple.
iteration: The current iteration of the Sinkhorn outer loop.
compute_error: Flag indicating whether this iteration computes/stores an error.
Returns:
The updated state.
"""
previous_state = state
it = iteration // self.inner_iterations
if self.lse_mode: # In lse_mode, run additive updates.
state = self.lse_step(ot_prob, state, iteration)
else:
state = self.kernel_step(ot_prob, state, iteration)
# recompute the cost (when `compute_error`) and the error once past
# `min_iterations`; otherwise set them to inf.
cost = jax.lax.cond(
jnp.logical_and(compute_error, iteration >= self.min_iterations),
lambda: state.reg_ot_cost(ot_prob, epsilon=self.epsilon),
lambda: jnp.inf
)
error = jax.lax.cond(
iteration >= self.min_iterations,
lambda: state.compute_error(previous_state), lambda: jnp.inf
)
crossed_threshold = jnp.logical_or(
state.crossed_threshold,
jnp.logical_and(
state.errors[it - 1] >= self.threshold, error < self.threshold
)
)
state = state.set(
costs=state.costs.at[it].set(cost),
errors=state.errors.at[it].set(error),
crossed_threshold=crossed_threshold,
)
if self.progress_fn is not None:
jax.debug.callback(
self.progress_fn,
(iteration, self.inner_iterations, self.max_iterations, state)
)
return state
@property
def norm_error(self) -> Tuple[int]: # noqa: D102
return self._norm_error,
def init_state(
self, ot_prob: linear_problem.LinearProblem,
init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
) -> LRSinkhornState:
"""Return the initial state of the loop."""
q, r, g = init
return LRSinkhornState(
q=q,
r=r,
g=g,
gamma=self.gamma,
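# -1 is a sentinel for entries of `costs`/`errors` not yet recorded.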
costs=-jnp.ones(self.outer_iterations),
errors=-jnp.ones(self.outer_iterations),
crossed_threshold=False,
)
def output_from_state(
self, ot_prob: linear_problem.LinearProblem, state: LRSinkhornState
) -> LRSinkhornOutput:
"""Create an output from a loop state.
Args:
ot_prob: The transport problem.
state: A LRSinkhornState.
Returns:
A LRSinkhornOutput.
"""
it = jnp.sum(state.errors != -1.0) * self.inner_iterations
converged = self._converged(state, it)
return LRSinkhornOutput(
q=state.q,
r=state.r,
g=state.g,
ot_prob=ot_prob,
costs=state.costs,
errors=state.errors,
epsilon=self.epsilon,
inner_iterations=self.inner_iterations,
converged=converged,
)
def _converged(self, state: LRSinkhornState, iteration: int) -> bool:
def conv_crossed(prev_err: float, curr_err: float) -> bool:
return jnp.logical_and(
prev_err < self.threshold, curr_err < self.threshold
)
def conv_not_crossed(prev_err: float, curr_err: float) -> bool:
return jnp.logical_and(curr_err < prev_err, curr_err < self.threshold)
# for convergence error, we consider 2 possibilities:
# 1. we either crossed the convergence threshold; in this case we require
# that the previous error was also below the threshold
# 2. we haven't crossed the threshold; in this case, we can be below or
# above the threshold:
# if we're above, we wait until we reach the convergence threshold and
# then, the above condition applies
# if we're below and we improved w.r.t. the previous iteration,
# we have converged; otherwise we continue, since we may be stuck
# in a local minimum (e.g., during the initial iterations)
it = iteration // self.inner_iterations
return jax.lax.cond(
state.crossed_threshold, conv_crossed, conv_not_crossed,
state.errors[it - 2], state.errors[it - 1]
)
def _diverged(self, state: LRSinkhornState, iteration: int) -> bool:
it = iteration // self.inner_iterations - 1
is_not_finite = jnp.logical_and(
jnp.logical_not(jnp.isfinite(state.errors[it])),
jnp.logical_not(jnp.isfinite(state.costs[it]))
)
# `jnp.inf` is used if `it < self.min_iterations`
return jnp.logical_and(it >= self.min_iterations, is_not_finite)
def run(
ot_prob: linear_problem.LinearProblem,
solver: LRSinkhorn,
init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
) -> LRSinkhornOutput:
"""Run loop of the solver, outputting a state upgraded to an output."""
out = sinkhorn.iterations(ot_prob, solver, init)
out = out.set_cost(
ot_prob, lse_mode=solver.lse_mode, use_danskin=solver.use_danskin
)
return out.set(ot_prob=ot_prob)