# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import (
Any,
Callable,
Literal,
Mapping,
NamedTuple,
Optional,
Sequence,
Tuple,
Union,
)
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np
from ott import utils
from ott.geometry import geometry
from ott.initializers.linear import initializers as init_lib
from ott.math import fixed_point_loop
from ott.math import unbalanced_functions as uf
from ott.math import utils as mu
from ott.problems.linear import linear_problem, potentials
from ott.solvers.linear import acceleration
from ott.solvers.linear import implicit_differentiation as implicit_lib
__all__ = ["Sinkhorn", "SinkhornOutput"]
ProgressCallbackFn_t = Callable[
[Tuple[np.ndarray, np.ndarray, np.ndarray, "SinkhornState"]], None]
class SinkhornState(NamedTuple):
"""Holds the state variables used to solve OT with Sinkhorn."""
potentials: Tuple[jnp.ndarray, ...]
errors: Optional[jnp.ndarray] = None
old_fus: Optional[jnp.ndarray] = None
old_mapped_fus: Optional[jnp.ndarray] = None
def set(self, **kwargs: Any) -> "SinkhornState":
"""Return a copy of self, with potential overwrites."""
return self._replace(**kwargs)
def solution_error(
self,
ot_prob: linear_problem.LinearProblem,
norm_error: Sequence[int],
*,
lse_mode: bool,
parallel_dual_updates: bool,
recenter: bool,
) -> jnp.ndarray:
"""State dependent function to return error."""
fu, gv = self.fu, self.gv
if recenter and lse_mode:
fu, gv = self.recenter(fu, gv, ot_prob=ot_prob)
return solution_error(
fu,
gv,
ot_prob,
norm_error=norm_error,
lse_mode=lse_mode,
parallel_dual_updates=parallel_dual_updates
)
def compute_kl_reg_cost( # noqa: D102
self, ot_prob: linear_problem.LinearProblem, lse_mode: bool
) -> jnp.ndarray:
return compute_kl_reg_cost(self.fu, self.gv, ot_prob, lse_mode)
def recenter(
self,
f: jnp.ndarray,
g: jnp.ndarray,
ot_prob: linear_problem.LinearProblem,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
"""Re-center dual potentials.
If the ``ot_prob`` is balanced, the ``f`` potential is zero-centered.
Otherwise, use Prop. 2 of :cite:`sejourne:22` to re-center the potentials iff
``tau_a < 1`` and ``tau_b < 1``.
Args:
f: The first dual potential.
g: The second dual potential.
ot_prob: Linear OT problem.
Returns:
The centered potentials.
"""
if ot_prob.is_balanced:
# center the potentials for numerical stability
is_finite = jnp.isfinite(f)
shift = jnp.sum(jnp.where(is_finite, f, 0.0)) / jnp.sum(is_finite)
return f - shift, g + shift
if ot_prob.tau_a == 1.0 or ot_prob.tau_b == 1.0:
# re-centering wasn't done during the lse-step, ignore
return f, g
rho_a = uf.rho(ot_prob.epsilon, ot_prob.tau_a)
rho_b = uf.rho(ot_prob.epsilon, ot_prob.tau_b)
tau = rho_a * rho_b / (rho_a + rho_b)
shift = tau * (
mu.logsumexp(-f / rho_a, b=ot_prob.a) -
mu.logsumexp(-g / rho_b, b=ot_prob.b)
)
return f + shift, g - shift
@property
def fu(self) -> jnp.ndarray:
"""The first dual potential or scaling."""
return self.potentials[0]
@property
def gv(self) -> jnp.ndarray:
"""The second dual potential or scaling."""
return self.potentials[1]
def solution_error(
f_u: jnp.ndarray,
g_v: jnp.ndarray,
ot_prob: linear_problem.LinearProblem,
*,
norm_error: Sequence[int],
lse_mode: bool,
parallel_dual_updates: bool,
) -> jnp.ndarray:
"""Given two potential/scaling solutions, computes deviation to optimality.
When the ``ot_prob`` problem is balanced and the usual Sinkhorn updates are
used, this is simply deviation of the coupling's marginal to ``ot_prob.b``.
This is the case because the second (and last) update of the Sinkhorn
algorithm equalizes the row marginal of the coupling to ``ot_prob.a``. To
simplify the logic, this is parameterized by checking whether
`parallel_dual_updates = False`.
When that flag is `True`, or when the problem is unbalanced,
additional quantities to qualify optimality must be taken into account.
Args:
f_u: jnp.ndarray, potential or scaling
g_v: jnp.ndarray, potential or scaling
ot_prob: linear OT problem
norm_error: int, p-norm used to compute error.
lse_mode: True if log-sum-exp operations, False if kernel vector products.
parallel_dual_updates: Whether potentials/scalings were computed in
parallel.
Returns:
a positive number quantifying how far from optimality current solution is.
"""
if ot_prob.is_balanced and not parallel_dual_updates:
return marginal_error(
f_u, g_v, ot_prob.b, ot_prob.geom, 0, norm_error, lse_mode
)
# In the unbalanced case, we compute the norm of the gradient.
# The gradient is equal to the marginal of the current plan minus
# the gradient of <z, rho_z(exp(-h/rho_z) - 1)>, where z is either a or b
# and h is either f or g. Note this is equal to z if rho_z → inf, which
# is the case when tau_z → 1.0.
if lse_mode:
grad_a = uf.grad_of_marginal_fit(
ot_prob.a, f_u, ot_prob.tau_a, ot_prob.epsilon
)
grad_b = uf.grad_of_marginal_fit(
ot_prob.b, g_v, ot_prob.tau_b, ot_prob.epsilon
)
else:
u = ot_prob.geom.potential_from_scaling(f_u)
v = ot_prob.geom.potential_from_scaling(g_v)
grad_a = uf.grad_of_marginal_fit(
ot_prob.a, u, ot_prob.tau_a, ot_prob.epsilon
)
grad_b = uf.grad_of_marginal_fit(
ot_prob.b, v, ot_prob.tau_b, ot_prob.epsilon
)
err = marginal_error(f_u, g_v, grad_a, ot_prob.geom, 1, norm_error, lse_mode)
err += marginal_error(f_u, g_v, grad_b, ot_prob.geom, 0, norm_error, lse_mode)
return err
def marginal_error(
f_u: jnp.ndarray,
g_v: jnp.ndarray,
target: jnp.ndarray,
geom: geometry.Geometry,
axis: int = 0,
norm_error: Sequence[int] = (1,),
lse_mode: bool = True
) -> jnp.ndarray:
"""Output how far the Sinkhorn solution is w.r.t. the target.
Args:
f_u: a vector of potentials or scalings for the first marginal.
g_v: a vector of potentials or scalings for the second marginal.
target: target marginal.
geom: Geometry object.
axis: axis (0 or 1) along which to compute marginal.
norm_error: (tuple of int) p's to compute p-norm between marginal/target
lse_mode: whether operating on scalings or potentials
Returns:
Array of floats, quantifying difference between target / marginal.
"""
if lse_mode:
marginal = geom.marginal_from_potentials(f_u, g_v, axis=axis)
else:
marginal = geom.marginal_from_scalings(f_u, g_v, axis=axis)
norm_error = jnp.asarray(norm_error)
return jnp.sum(
jnp.abs(marginal - target) ** norm_error[:, jnp.newaxis], axis=1
) ** (1.0 / norm_error)
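# Illustrative only, not part of the solver: a minimal sketch, on small
# hypothetical point clouds, checking that `marginal_error` in lse mode matches
# a direct computation on the densified coupling obtained from the same
# (arbitrary, non-optimal) potentials.
def _marginal_error_dense_check() -> jnp.ndarray:
  from ott.geometry import pointcloud
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (5, 2))
  y = jax.random.normal(rng_y, (4, 2))
  geom = pointcloud.PointCloud(x, y, epsilon=0.1)
  b = jnp.ones((4,)) / 4.0
  f, g = jnp.zeros((5,)), jnp.zeros((4,))
  # error of the coupling's second marginal (axis=0) w.r.t. `b`, in 1-norm
  err = marginal_error(f, g, b, geom, axis=0, norm_error=(1,), lse_mode=True)
  plan = geom.transport_from_potentials(f, g)
  dense_err = jnp.sum(jnp.abs(plan.sum(axis=0) - b))
  return jnp.abs(err[0] - dense_err)  # expected to be ~0, up to float error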
def compute_kl_reg_cost(
f: jnp.ndarray, g: jnp.ndarray, ot_prob: linear_problem.LinearProblem,
lse_mode: bool
) -> jnp.ndarray:
r"""Compute objective of Sinkhorn for OT problem given dual solutions.
The objective is evaluated for dual solution ``f`` and ``g``, using
information contained in ``ot_prob``. The objective is the regularized
optimal transport cost (i.e. the cost itself plus entropic and unbalanced
terms). Situations where marginals ``a`` or ``b`` in ``ot_prob`` have zero
coordinates are reflected in minus infinity entries in their corresponding
dual potentials. To avoid the NaN values that may result from multiplying 0 by
infinity, ``jnp.where`` is used to cancel these contributions.
Args:
f: jnp.ndarray, potential
g: jnp.ndarray, potential
ot_prob: linear optimal transport problem.
lse_mode: bool, whether to compute total mass in lse or kernel mode.
Returns:
The regularized transport cost.
"""
supp_a = ot_prob.a > 0
supp_b = ot_prob.b > 0
fa = ot_prob.geom.potential_from_scaling(ot_prob.a)
if ot_prob.tau_a == 1.0:
div_a = jnp.sum(jnp.where(supp_a, ot_prob.a * (f - fa), 0.0))
else:
rho_a = uf.rho(ot_prob.epsilon, ot_prob.tau_a)
div_a = -jnp.sum(
jnp.where(supp_a, ot_prob.a * uf.phi_star(-(f - fa), rho_a), 0.0)
)
gb = ot_prob.geom.potential_from_scaling(ot_prob.b)
if ot_prob.tau_b == 1.0:
div_b = jnp.sum(jnp.where(supp_b, ot_prob.b * (g - gb), 0.0))
else:
rho_b = uf.rho(ot_prob.epsilon, ot_prob.tau_b)
div_b = -jnp.sum(
jnp.where(supp_b, ot_prob.b * uf.phi_star(-(g - gb), rho_b), 0.0)
)
# Using https://arxiv.org/pdf/1910.12958v2.pdf (24)
if lse_mode:
total_sum = jnp.sum(ot_prob.geom.marginal_from_potentials(f, g))
else:
u = ot_prob.geom.scaling_from_potential(f)
v = ot_prob.geom.scaling_from_potential(g)
total_sum = jnp.sum(ot_prob.geom.marginal_from_scalings(u, v))
return div_a + div_b + ot_prob.epsilon * (
jnp.sum(ot_prob.a) * jnp.sum(ot_prob.b) - total_sum
)
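# Illustrative only, not part of the solver: a hedged sketch, on hypothetical
# data, that solves a balanced problem to a tight tolerance and compares the
# dual value returned by `compute_kl_reg_cost` (i.e. `reg_ot_cost`) with the
# primal objective <P, C> + eps * KL(P | a b^T) evaluated at the coupling; the
# residual should shrink as the convergence threshold is tightened.
def _kl_reg_cost_primal_check() -> jnp.ndarray:
  from ott.geometry import pointcloud
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (12, 2))
  y = jax.random.normal(rng_y, (15, 2))
  geom = pointcloud.PointCloud(x, y, epsilon=0.5)
  prob = linear_problem.LinearProblem(geom)
  out = Sinkhorn(threshold=1e-6)(prob)
  plan, ab = out.matrix, jnp.outer(prob.a, prob.b)
  # generalized KL(P | ab^T), guarding 0 * log 0 as in the function above
  kl_terms = jnp.where(plan > 0, plan * jnp.log(plan / ab), 0.0)
  kl = jnp.sum(kl_terms) - jnp.sum(plan) + jnp.sum(ab)
  primal = jnp.sum(plan * geom.cost_matrix) + geom.epsilon * kl
  return jnp.abs(primal - out.reg_ot_cost)  # small at convergence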
class SinkhornOutput(NamedTuple):
"""Holds the output of a Sinkhorn solver applied to a problem.
Objects of this class contain both the solutions and the problem definition of
a regularized OT problem, along with several methods that can be used to access
its content, for instance to materialize an OT matrix or to apply it to a
vector (without having to materialize it when not needed).
Args:
potentials: tuple of optimal dual variables, two vectors of size
``ot_prob.shape[0]`` and ``ot_prob.shape[1]``, returned by Sinkhorn.
errors: vector of errors, along iterations. This vector is of size
``max_iterations // inner_iterations`` where those were the parameters
passed on to the :class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver.
For each entry indexed at ``i``, ``errors[i]`` can be either a real
non-negative value (meaning the algorithm recorded that error at the
``i * inner_iterations`` iteration), a ``jnp.inf`` value (meaning the
algorithm computed that iteration but did not compute its error, because,
for instance, ``i < min_iterations // inner_iterations``), or a ``-1``,
meaning that execution was terminated before that iteration, because the
criterion was found to be smaller than ``threshold``.
reg_ot_cost: the regularized optimal transport cost. By default this is
the linear contribution + KL term. See
:attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.ent_reg_cost`,
:attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.primal_cost` and
:attr:`~ott.solvers.linear.sinkhorn.SinkhornOutput.dual_cost` for other
objective values.
ot_prob: stores the definition of the OT problem, including geometry,
marginals, unbalanced regularizers, etc.
threshold: convergence threshold used to control the termination of the
algorithm.
converged: whether the output corresponds to a solution whose error is
below the convergence threshold.
inner_iterations: number of iterations that were run between two
computations of errors.
"""
potentials: Tuple[jnp.ndarray, ...]
errors: Optional[jnp.ndarray] = None
reg_ot_cost: Optional[jnp.ndarray] = None
ot_prob: Optional[linear_problem.LinearProblem] = None
threshold: Optional[jnp.ndarray] = None
converged: Optional[bool] = None
inner_iterations: Optional[int] = None
def set(self, **kwargs: Any) -> "SinkhornOutput":
"""Return a copy of self, with potential overwrites."""
return self._replace(**kwargs)
def set_cost( # noqa: D102
self, ot_prob: linear_problem.LinearProblem, lse_mode: bool,
use_danskin: bool
) -> "SinkhornOutput":
f = jax.lax.stop_gradient(self.f) if use_danskin else self.f
g = jax.lax.stop_gradient(self.g) if use_danskin else self.g
return self.set(reg_ot_cost=compute_kl_reg_cost(f, g, ot_prob, lse_mode))
@property
def dual_cost(self) -> jnp.ndarray:
"""Return dual transport cost, without considering regularizer."""
a, b = self.ot_prob.a, self.ot_prob.b
dual_cost = jnp.sum(jnp.where(a > 0.0, a * self.f, 0))
dual_cost += jnp.sum(jnp.where(b > 0.0, b * self.g, 0))
return dual_cost
@property
def primal_cost(self) -> jnp.ndarray:
"""Return transport cost of current transport solution at geometry."""
return self.transport_cost_at_geom(other_geom=self.geom)
@property
def ent_reg_cost(self) -> jnp.ndarray:
r"""Entropy regularized cost.
This outputs
.. math::
\langle P^{\star},C\rangle - \varepsilon H(P^{\star}) +
\rho_a\text{KL}(P^{\star} 1|a) + \rho_b\text{KL}(1^T P^{\star}|b),
where :math:`P^{\star}, a, b` are the coupling returned by the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` solver and the two marginal weight
vectors; :math:`\rho_a=\varepsilon \tau_a / (1-\tau_a)` and
:math:`\rho_b=\varepsilon \tau_b / (1-\tau_b)` are obtained when the problem
is unbalanced from parameters ``tau_a`` and ``tau_b``. Note that the last
two terms vanish in the balanced case, when ``tau_a==tau_b==1``.
"""
ent_a = jnp.sum(jsp.special.entr(self.ot_prob.a))
ent_b = jnp.sum(jsp.special.entr(self.ot_prob.b))
return self.reg_ot_cost - self.geom.epsilon * (ent_a + ent_b)
@property
def kl_reg_cost(self) -> jnp.ndarray:
r"""KL regularized OT transport cost.
This outputs
.. math::
\langle P^{\star}, C \rangle + \varepsilon \text{KL}(P^{\star}|ab^T) +
\rho_a\text{KL}(P^{\star} 1|a) + \rho_b\text{KL}(1^T P^{\star}|b),
where :math:`P^{\star}, a, b` are the coupling returned by the
:class:`~ott.solvers.linear.sinkhorn.Sinkhorn` algorithm and the two
marginal weight vectors, respectively, and
:math:`\rho_a=\varepsilon \tau_a / (1-\tau_a)` and
:math:`\rho_b=\varepsilon \tau_b / (1-\tau_b)` are obtained when the problem
is unbalanced from parameters ``tau_a`` and ``tau_b``. Note that the last
two terms vanish in the balanced case, when ``tau_a==tau_b==1``. This
quantity coincides with :attr:`reg_ot_cost`, which is computed using
dual variables.
"""
return self.reg_ot_cost
def transport_cost_at_geom(
self, other_geom: geometry.Geometry
) -> jnp.ndarray:
r"""Return bare transport cost of current solution at any geometry.
In order to compute cost, we check first if the geometry can be converted
to a low-rank cost geometry in order to speed up computations, without
having to materialize the full cost matrix. If this is not possible,
we resort to instantiating both transport matrix and cost matrix.
Args:
other_geom: geometry whose cost matrix is used to evaluate the transport
cost.
Returns:
the transportation cost at :math:`C`, i.e. :math:`\langle P, C \rangle`.
"""
# TODO(cuturi): handle online mode for non Euclidean pointcloud geometries.
# TODO(michalk8): handle SqEucl point cloud is not converted to LRCGeom
if other_geom.can_LRC:
geom = other_geom.to_LRCGeometry()
return jnp.sum(self.apply(geom.cost_1.T) * geom.cost_2.T)
return jnp.sum(self.matrix * other_geom.cost_matrix)
@property
def geom(self) -> geometry.Geometry: # noqa: D102
return self.ot_prob.geom
@property
def a(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.a
@property
def b(self) -> jnp.ndarray: # noqa: D102
return self.ot_prob.b
@property
def n_iters(self) -> int: # noqa: D102
"""Returns the total number of iterations that were needed to terminate."""
return jnp.sum(self.errors != -1) * self.inner_iterations
@property
def scalings(self) -> Tuple[jnp.ndarray, jnp.ndarray]: # noqa: D102
u = self.ot_prob.geom.scaling_from_potential(self.f)
v = self.ot_prob.geom.scaling_from_potential(self.g)
return u, v
@property
def matrix(self) -> jnp.ndarray:
"""Transport matrix if it can be instantiated."""
try:
return self.ot_prob.geom.transport_from_potentials(self.f, self.g)
except ValueError:
return self.ot_prob.geom.transport_from_scalings(*self.scalings)
@property
def transport_mass(self) -> jnp.ndarray:
"""Sum of transport matrix."""
return self.marginal(0).sum()
def apply(
self,
inputs: jnp.ndarray,
axis: int = 0,
lse_mode: bool = True
) -> jnp.ndarray:
"""Apply the transport to a ndarray; axis=1 for its transpose."""
geom = self.ot_prob.geom
if lse_mode:
return geom.apply_transport_from_potentials(
self.f, self.g, inputs, axis=axis
)
u = geom.scaling_from_potential(self.f)
v = geom.scaling_from_potential(self.g)
return geom.apply_transport_from_scalings(u, v, inputs, axis=axis)
def marginal(self, axis: int) -> jnp.ndarray: # noqa: D102
return self.ot_prob.geom.marginal_from_potentials(self.f, self.g, axis=axis)
def cost_at_geom(self, other_geom: geometry.Geometry) -> jnp.ndarray:
"""Return reg-OT cost for matrix, evaluated at other cost matrix."""
return (
jnp.sum(self.matrix * other_geom.cost_matrix) -
self.geom.epsilon * jnp.sum(jax.scipy.special.entr(self.matrix))
)
def to_dual_potentials(self) -> potentials.EntropicPotentials:
"""Return the entropic map estimator."""
return potentials.EntropicPotentials(self.f, self.g, self.ot_prob)
@property
def f(self) -> jnp.ndarray:
"""The first dual potential."""
return self.potentials[0]
@property
def g(self) -> jnp.ndarray:
"""The second dual potential."""
return self.potentials[1]
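# Illustrative only, not part of the solver: a minimal sketch, on hypothetical
# point clouds, showing that `SinkhornOutput.apply`, which only uses the dual
# potentials, matches multiplying a vector by the densified `matrix` property.
def _sinkhorn_output_apply_check() -> jnp.ndarray:
  from ott.geometry import pointcloud
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (16, 3))
  y = jax.random.normal(rng_y, (14, 3))
  geom = pointcloud.PointCloud(x, y, epsilon=0.5)
  out = Sinkhorn()(linear_problem.LinearProblem(geom))
  vec = jnp.arange(14.0)
  pushed = out.apply(vec, axis=1)  # applies the (16, 14) coupling to `vec`
  dense = out.matrix @ vec  # same product, via the materialized coupling
  return jnp.max(jnp.abs(pushed - dense))  # expected to be ~0, up to float error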
@jax.tree_util.register_pytree_node_class
class Sinkhorn:
r"""Sinkhorn solver.
The Sinkhorn algorithm is a fixed point iteration that solves a regularized
optimal transport (reg-OT) problem between two measures.
The optimization variables are a pair of vectors (called potentials, or
scalings when parameterized as exponential of the former). Calling this
function returns therefore a pair of optimal vectors. In addition to these,
it also returns the objective value achieved by these optimal vectors;
a vector of size ``max_iterations/inner_iterations`` that records the errors
computed to monitor convergence throughout the execution of the algorithm
(padded with ``-1`` if convergence happens earlier), as well as a boolean to
signify whether the algorithm has converged within the number of iterations
specified by the user.
The reg-OT problem is specified by two measures, of respective sizes ``n`` and
``m``. From the viewpoint of the ``sinkhorn`` function, these two measures are
only seen through a triplet (``geom``, ``a``, ``b``), where ``geom`` is a
``Geometry`` object, and ``a`` and ``b`` are weight vectors of respective
sizes ``n`` and ``m``. Starting from two initial values for those potentials
or scalings (both can be defined by the user by passing value in
``init_dual_a`` or ``init_dual_b``), the Sinkhorn algorithm will use
elementary operations that are carried out by the ``geom`` object.
Math:
Given a geometry ``geom``, which provides a cost matrix :math:`C` with its
regularization parameter :math:`\varepsilon`, (or a kernel matrix :math:`K`)
the reg-OT problem consists in finding two vectors `f`, `g` of size ``n``,
``m`` that maximize the following criterion.
.. math::
\arg\max_{f, g} - \langle a, \phi_a^{*}(-f) \rangle - \langle b,
\phi_b^{*}(-g) \rangle - \varepsilon \langle e^{f/\varepsilon},
e^{-C/\varepsilon} e^{g/\varepsilon} \rangle
where :math:`\phi_a(z) = \rho_a z(\log z - 1)` is a scaled entropy, and
:math:`\phi_a^{*}(z) = \rho_a e^{z/\rho_a}`, its Legendre transform
:cite:`sejourne:19`.
That problem can also be written, instead, using positive scaling vectors
`u`, `v` of size ``n``, ``m``, handled with the kernel
:math:`K := e^{-C/\varepsilon}`,
.. math::
\arg\max_{u, v > 0} - \langle a, \phi_a^{*}(-\varepsilon\log u) \rangle -
\langle b, \phi_b^{*}(-\varepsilon\log v) \rangle -
\varepsilon \langle u, K v \rangle
Both of these problems correspond, in their *primal* formulation, to
solving the unbalanced optimal transport problem with a variable matrix
:math:`P` of size ``n`` x ``m``:
.. math::
\arg\min_{P>0} \langle P,C \rangle +\varepsilon \text{KL}(P | ab^T)
+ \rho_a \text{KL}(P\mathbf{1}_m | a) + \rho_b \text{KL}(P^T \mathbf{1}_n
| b)
where :math:`KL` is the generalized Kullback-Leibler divergence.
The very same primal problem can also be written using a kernel :math:`K`
instead of a cost :math:`C` as well:
.. math::
\arg\min_{P} \varepsilon \text{KL}(P|K)
+ \rho_a \text{KL}(P\mathbf{1}_m | a) +
\rho_b \text{KL}(P^T \mathbf{1}_n | b)
The *original* OT problem taught in linear programming courses is recovered
by using the formulation above relying on the cost :math:`C`, and letting
:math:`\varepsilon \rightarrow 0`, and :math:`\rho_a, \rho_b \rightarrow
\infty`.
In that case the entropy disappears, whereas the :math:`KL` regularizations
above become constraints on the marginals of :math:`P`: this results in a
standard min cost flow problem. This problem is not handled for now in this
toolbox, which focuses exclusively on the case :math:`\varepsilon > 0`.
The *balanced* regularized OT problem is recovered for finite
:math:`\varepsilon > 0` but letting :math:`\rho_a, \rho_b \rightarrow
\infty`. This problem can be shown to be equivalent to a matrix scaling
problem, which can be solved using the Sinkhorn fixed-point algorithm.
To handle the case :math:`\rho_a, \rho_b \rightarrow \infty`, the
``sinkhorn`` function uses parameters ``tau_a`` and ``tau_b`` equal
respectively to :math:`\rho_a /(\varepsilon + \rho_a)` and
:math:`\rho_b / (\varepsilon + \rho_b)` instead. Setting either of these
parameters to 1 corresponds to setting the corresponding
:math:`\rho_a, \rho_b` to :math:`\infty`.
The Sinkhorn algorithm solves the reg-OT problem by seeking optimal
:math:`f`, :math:`g` potentials (or alternatively their parameterization
as positive scaling vectors :math:`u`, :math:`v`), rather than solving the
primal problem in :math:`P`. This is mostly for efficiency (potentials and
scalings have a ``n + m`` memory footprint, rather than ``n m`` required
to store `P`). This is also because both problems are, in fact, equivalent,
since the optimal transport :math:`P^{\star}` can be recovered from
optimal potentials :math:`f^{\star}`, :math:`g^{\star}` or scaling
:math:`u^{\star}`, :math:`v^{\star}`, using the geometry's cost or kernel
matrix respectively:
.. math::
P^{\star} = \exp\left(\frac{f^{\star}\mathbf{1}_m^T + \mathbf{1}_n g^{*T}-
C}{\varepsilon}\right) \text{ or } P^{\star} = \text{diag}(u^{\star}) K
\text{diag}(v^{\star})
By default, the Sinkhorn algorithm solves this dual problem in :math:`f, g`
or :math:`u, v` using block coordinate ascent, i.e. devising an update for
each :math:`f` and :math:`g` (resp. :math:`u` and :math:`v`) that cancels
their respective gradients, one at a time. These two iterations are repeated
``inner_iterations`` times, after which the norm of these gradients will be
evaluated and compared with the ``threshold`` value. The iterations are then
repeated as long as that error exceeds ``threshold``.
Note on Sinkhorn updates:
The boolean flag ``lse_mode`` sets whether the algorithm is run in either:
- log-sum-exp mode (``lse_mode=True``), in which case it is directly
defined in terms of updates to `f` and `g`, using log-sum-exp
computations. This requires access to the cost matrix :math:`C`, as it is
stored, or possibly computed on the fly by ``geom``.
- kernel mode (``lse_mode=False``), in which case it will require access
to a matrix vector multiplication operator :math:`z \rightarrow K z`,
where :math:`K` is either instantiated from :math:`C` as
:math:`\exp(-C/\varepsilon)`, or provided directly. In that case, rather
than optimizing on :math:`f` and :math:`g`, it is more convenient to
optimize on their so called scaling formulations,
:math:`u := \exp(f / \varepsilon)` and :math:`v := \exp(g / \varepsilon)`.
While faster (applying matrices is faster than applying ``lse`` repeatedly
over lines), this mode is also less stable numerically, notably for
smaller :math:`\varepsilon`.
In the source code, the variables ``f_u`` or ``g_v`` can be either regarded
as potentials (real) or scalings (positive) vectors, depending on the choice
of ``lse_mode`` by the user. Once optimization is carried out, we only
return dual variables in potential form, i.e. ``f`` and ``g``.
In addition to standard Sinkhorn updates, the user can also use heavy-ball
type updates using a ``momentum`` parameter in ]0,2[. We also implement a
strategy that tries to set that parameter adaptively at
``chg_momentum_from`` iterations, as a function of progress in the error,
as proposed in the literature.
Another upgrade to the standard Sinkhorn updates provided to the users lies
in using Anderson acceleration. This can be parameterized by setting the
otherwise ``None`` argument ``anderson`` to an
:class:`~ott.solvers.linear.acceleration.AndersonAcceleration` instance. When
selected, the algorithm will recompute, every ``refresh_anderson_frequency``
(set by default to 1) iterations, an extrapolation of the most recently
computed ``anderson`` iterates. When using that option, notice that
differentiation (if required) can only be carried out using implicit
differentiation, and that all momentum related parameters are ignored.
The ``parallel_dual_updates`` flag is set to ``False`` by default. In that
setting, ``g_v`` is first updated using the latest values for ``f_u`` and
``g_v``, before proceeding to update ``f_u`` using that new value for
``g_v``. When the flag is set to ``True``, both ``f_u`` and ``g_v`` are
updated simultaneously. Note that setting that flag to ``True`` requires
using some form of averaging (e.g. a ``momentum`` with ``value=0.5``);
on its own, ``parallel_dual_updates`` won't work.
Differentiation:
The optimal solutions ``f`` and ``g`` and the optimal objective
(``reg_ot_cost``) outputted by the Sinkhorn algorithm can be differentiated
w.r.t. relevant inputs ``geom``, ``a`` and ``b``. In the default setting,
which uses implicit differentiation of the optimality conditions
(``implicit_diff`` not equal to ``None``), this has two consequences, treating
``f`` and ``g`` differently from ``reg_ot_cost``.
- The termination criterion used to stop Sinkhorn (cancellation of
gradient of objective w.r.t. ``f_u`` and ``g_v``) is used to differentiate
``f`` and ``g``, given a change in the inputs. These changes are computed
by solving a linear system. The arguments starting with
``implicit_solver_*`` make it possible to define the linear solver that is
used, and to control two types of regularization (we have observed that,
depending on the architecture, linear solves may require higher ridge
parameters to remain stable). The optimality conditions in Sinkhorn can be
analyzed as satisfying a ``z=z'`` condition, which is then
differentiated. It might be beneficial (e.g., as in :cite:`cuturi:20a`)
to use a preconditioning function ``precondition_fun`` to differentiate
instead the equivalent condition ``h(z) = h(z')``.
- The objective ``reg_ot_cost`` returned by Sinkhorn uses the so-called
envelope (or Danskin's) theorem. In that case, because it is assumed that
the gradients of the dual variables ``f_u`` and ``g_v`` w.r.t. dual
objective are zero (reflecting the fact that they are optimal), small
variations in ``f_u`` and ``g_v`` due to changes in inputs (such as
``geom``, ``a`` and ``b``) are considered negligible. As a result,
``stop_gradient`` is applied on dual variables ``f_u`` and ``g_v`` when
evaluating the ``reg_ot_cost`` objective. Note that this approach is
`invalid` when computing higher order derivatives. In that case the
``use_danskin`` flag must be set to ``False``.
An alternative yet more costly way to differentiate the outputs of the
Sinkhorn iterations is to use unrolling, i.e. reverse mode differentiation
of the Sinkhorn loop. This is possible because Sinkhorn iterations are
wrapped in a custom fixed point iteration loop, defined in
``fixed_point_loop``, rather than a standard while loop. This is to ensure
the end result of this fixed point loop can also be differentiated, if
needed, using standard JAX operations. To ensure differentiability,
the ``fixed_point_loop.fixpoint_iter_backprop`` loop does checkpointing of
state variables (here ``f_u`` and ``g_v``) every ``inner_iterations``, and
backpropagates automatically, block by block, through blocks of
``inner_iterations`` at a time.
Note:
* The Sinkhorn algorithm may not converge within the maximum number of
iterations for possibly several reasons:
1. the regularizer (defined as ``epsilon`` in the geometry ``geom``
object) is too small. Consider either switching to ``lse_mode=True``
(at the price of a slower execution), increasing ``epsilon``, or,
alternatively, if you are unable or unwilling to increase ``epsilon``,
either increase ``max_iterations`` or ``threshold``.
2. the probability weights ``a`` and ``b`` do not have the same total
mass, while using a balanced (``tau_a=tau_b=1.0``) setup.
Consider either normalizing ``a`` and ``b``, or setting ``tau_a`` and/or
``tau_b`` to be strictly smaller than ``1.0``.
3. OOM issues may arise when storing either cost or kernel matrices that
are too large in ``geom``. In the case where the ``geom`` geometry is
a ``PointCloud``, some of these issues might be solved by setting the
``online`` flag to ``True``. This will trigger a re-computation on the
fly of the cost/kernel matrix.
* The weight vectors ``a`` and ``b`` can be passed on with coordinates that
have zero weight. This is then handled by relying on simple arithmetic for
``inf`` values that will likely arise (due to :math:`\log 0` when
``lse_mode`` is ``True``, or divisions by zero when ``lse_mode`` is
``False``). Whenever that arithmetic is likely to produce ``NaN`` values
(due to ``-inf * 0``, or ``-inf - -inf``) in the forward pass, we use
``jnp.where`` conditional statements to carry ``inf`` rather than ``NaN``
values. In the reverse mode differentiation, the inputs corresponding to
these 0 weights (a location `x`, or a row in the corresponding cost/kernel
matrix), and the weight itself will have ``NaN`` gradient values. This
reflects that these gradients are undefined, since these points were not
considered in the optimization and have therefore no impact on the output.
Args:
lse_mode: ``True`` for log-sum-exp computations, ``False`` for kernel
multiplication.
threshold: tolerance used to stop the Sinkhorn iterations. This is
typically the deviation between a target marginal and the marginal of the
current primal solution when either or both ``tau_a`` and ``tau_b`` are 1.0
(balanced or semi-balanced problem), or the relative change between two
successive solutions in the unbalanced case.
norm_error: power used to define p-norm of error for marginal/target.
inner_iterations: the Sinkhorn error is not recomputed at each
iteration but every ``inner_iterations`` instead.
min_iterations: the minimum number of Sinkhorn iterations carried
out before the error is computed and monitored.
max_iterations: the maximum number of Sinkhorn iterations. If
``max_iterations`` is equal to ``min_iterations``, Sinkhorn iterations are
run by default using a :func:`jax.lax.scan` loop rather than a custom,
unroll-able :func:`jax.lax.while_loop` that monitors convergence.
In that case the error is not monitored and the ``converged``
flag will return ``False`` as a consequence.
momentum: Momentum instance.
anderson: AndersonAcceleration instance.
implicit_diff: instance used to solve implicit differentiation. Unrolls
iterations if None.
parallel_dual_updates: updates potentials or scalings in parallel if True,
sequentially (in Gauss-Seidel fashion) if False.
recenter_potentials: Whether to re-center the dual potentials.
If the problem is balanced, the ``f`` potential is zero-centered for
numerical stability. Otherwise, use the approach of :cite:`sejourne:22`
to achieve faster convergence. Only used when ``lse_mode = True`` and
``tau_a < 1`` and ``tau_b < 1``.
use_danskin: when ``True``, it is assumed the entropy regularized cost
is evaluated using optimal potentials that are frozen, i.e. whose
gradients have been stopped. This is useful when carrying out first order
differentiation, and is only valid (as with ``implicit_differentiation``)
when the algorithm has converged with a low tolerance.
initializer: how to compute the initial potentials/scalings. This refers to
a few possible classes implemented following the template in
:class:`~ott.initializers.linear.SinkhornInitializer`.
progress_fn: callback function which gets called during the Sinkhorn
iterations, so the user can display the error at each iteration,
e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn`
for a basic implementation.
kwargs_init: keyword arguments when creating the initializer.
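Example:
  A minimal usage sketch (illustrative only; the point clouds and parameter
  values below are hypothetical)::

    import jax
    from ott.geometry import pointcloud
    from ott.problems.linear import linear_problem
    from ott.solvers.linear import sinkhorn

    rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(rng_x, (30, 2))
    y = jax.random.normal(rng_y, (40, 2)) + 1.0
    geom = pointcloud.PointCloud(x, y, epsilon=5e-2)
    prob = linear_problem.LinearProblem(geom)
    out = sinkhorn.Sinkhorn(threshold=1e-3, lse_mode=True)(prob)
    print(out.reg_ot_cost, out.converged)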
"""
def __init__(
self,
lse_mode: bool = True,
threshold: float = 1e-3,
norm_error: int = 1,
inner_iterations: int = 10,
min_iterations: int = 0,
max_iterations: int = 2000,
momentum: Optional[acceleration.Momentum] = None,
anderson: Optional[acceleration.AndersonAcceleration] = None,
parallel_dual_updates: bool = False,
recenter_potentials: bool = False,
use_danskin: Optional[bool] = None,
implicit_diff: Optional[implicit_lib.ImplicitDiff
] = implicit_lib.ImplicitDiff(), # noqa: B008
initializer: Union[Literal["default", "gaussian", "sorting", "subsample"],
init_lib.SinkhornInitializer] = "default",
progress_fn: Optional[ProgressCallbackFn_t] = None,
kwargs_init: Optional[Mapping[str, Any]] = None,
):
self.lse_mode = lse_mode
self.threshold = threshold
self.inner_iterations = inner_iterations
self.min_iterations = min_iterations
self.max_iterations = max_iterations
self._norm_error = norm_error
self.anderson = anderson
self.implicit_diff = implicit_diff
if momentum is not None:
self.momentum = acceleration.Momentum(
momentum.start, momentum.error_threshold, momentum.value,
self.inner_iterations
)
else:
# Use no momentum if using Anderson or unrolling.
if self.anderson is not None or self.implicit_diff is None:
self.momentum = acceleration.Momentum(
inner_iterations=self.inner_iterations
)
else:
# no momentum
self.momentum = acceleration.Momentum()
self.parallel_dual_updates = parallel_dual_updates
self.recenter_potentials = recenter_potentials
self.initializer = initializer
self.progress_fn = progress_fn
self.kwargs_init = {} if kwargs_init is None else kwargs_init
# Force implicit differentiation when using Anderson acceleration and reset
# all momentum parameters to their default (i.e. no momentum).
if anderson:
self.implicit_diff = (
implicit_lib.ImplicitDiff()
if self.implicit_diff is None else self.implicit_diff
)
self.momentum = acceleration.Momentum(
inner_iterations=self.inner_iterations
)
# By default, use Danskin theorem to differentiate
# the objective when using implicit_lib.
self.use_danskin = ((self.implicit_diff is not None)
if use_danskin is None else use_danskin)
def __call__(
self,
ot_prob: linear_problem.LinearProblem,
init: Tuple[Optional[jnp.ndarray], Optional[jnp.ndarray]] = (None, None),
rng: Optional[jax.Array] = None,
) -> SinkhornOutput:
"""Run Sinkhorn algorithm.
Args:
ot_prob: Linear OT problem.
init: Initial dual potentials/scalings f_u and g_v, respectively.
Any `None` values will be initialized using the initializer.
rng: Random number generator key for stochastic initialization.
Returns:
The Sinkhorn output.
"""
rng = utils.default_prng_key(rng)
initializer = self.create_initializer()
init_dual_a, init_dual_b = initializer(
ot_prob, *init, lse_mode=self.lse_mode, rng=rng
)
return run(ot_prob, self, (init_dual_a, init_dual_b))
def lse_step(
self, ot_prob: linear_problem.LinearProblem, state: SinkhornState,
iteration: int
) -> SinkhornState:
"""Sinkhorn LSE update."""
def k(tau_i: float, tau_j: float) -> float:
num = -tau_j * (tau_a - 1) * (tau_b - 1) * (tau_i - 1)
denom = (tau_j - 1) * (tau_a * (tau_b - 1) + tau_b * (tau_a - 1))
return num / denom
def xi(tau_i: float, tau_j: float) -> float:
k_ij = k(tau_i, tau_j)
return k_ij / (1.0 - k_ij)
def smin(
potential: jnp.ndarray, marginal: jnp.ndarray, tau: float
) -> float:
rho = uf.rho(ot_prob.epsilon, tau)
return -rho * mu.logsumexp(-potential / rho, b=marginal)
# only for unbalanced problems with `tau_{a,b} < 1`
recenter = (
self.recenter_potentials and ot_prob.tau_a < 1.0 and ot_prob.tau_b < 1.0
)
w = self.momentum.weight(state, iteration)
tau_a, tau_b = ot_prob.tau_a, ot_prob.tau_b
old_fu, old_gv = state.fu, state.gv
if recenter:
k11, k22 = k(tau_a, tau_a), k(tau_b, tau_b)
xi12, xi21 = xi(tau_a, tau_b), xi(tau_b, tau_a)
# update g potential
new_gv = tau_b * ot_prob.geom.update_potential(
old_fu, old_gv, jnp.log(ot_prob.b), iteration, axis=0
)
if recenter:
new_gv -= k22 * smin(old_fu, ot_prob.a, tau_a)
new_gv += xi21 * smin(new_gv, ot_prob.b, tau_b)
gv = self.momentum(w, old_gv, new_gv, self.lse_mode)
if not self.parallel_dual_updates:
old_gv = gv
# update f potential
new_fu = tau_a * ot_prob.geom.update_potential(
old_fu, old_gv, jnp.log(ot_prob.a), iteration, axis=1
)
if recenter:
new_fu -= k11 * smin(old_gv, ot_prob.b, tau_b)
new_fu += xi12 * smin(new_fu, ot_prob.a, tau_a)
fu = self.momentum(w, old_fu, new_fu, self.lse_mode)
return state.set(potentials=(fu, gv))
def kernel_step(
self, ot_prob: linear_problem.LinearProblem, state: SinkhornState,
iteration: int
) -> SinkhornState:
"""Sinkhorn multiplicative update."""
w = self.momentum.weight(state, iteration)
old_gv = state.gv
new_gv = ot_prob.geom.update_scaling(
state.fu, ot_prob.b, iteration, axis=0
) ** ot_prob.tau_b
gv = self.momentum(w, state.gv, new_gv, self.lse_mode)
new_fu = ot_prob.geom.update_scaling(
old_gv if self.parallel_dual_updates else gv,
ot_prob.a,
iteration,
axis=1
) ** ot_prob.tau_a
fu = self.momentum(w, state.fu, new_fu, self.lse_mode)
return state.set(potentials=(fu, gv))
def one_iteration(
self, ot_prob: linear_problem.LinearProblem, state: SinkhornState,
iteration: int, compute_error: bool
) -> SinkhornState:
"""Carries out one Sinkhorn iteration.
Depending on lse_mode, these iterations can be either in:
- log-space for numerical stability.
- scaling space, using standard kernel-vector multiply operations.
Args:
ot_prob: the transport problem definition
state: SinkhornState named tuple.
iteration: the current iteration of the Sinkhorn loop.
compute_error: flag to indicate this iteration computes/stores an error
Returns:
The updated state.
"""
# When running updates in parallel (Jacobi fashion), old_g_v will be
# used to update f_u, rather than the latest g_v computed in this loop.
# Unused otherwise.
if self.anderson:
state = self.anderson.update(state, iteration, ot_prob, self.lse_mode)
if self.lse_mode: # In lse_mode, run additive updates.
state = self.lse_step(ot_prob, state, iteration)
else:
state = self.kernel_step(ot_prob, state, iteration)
if self.anderson:
state = self.anderson.update_history(state, ot_prob, self.lse_mode)
# re-computes error if compute_error is True, else set it to inf.
err = jax.lax.cond(
jnp.logical_or(
iteration == self.max_iterations - 1,
jnp.logical_and(compute_error, iteration >= self.min_iterations)
),
lambda state, prob: state.solution_error(
prob,
self.norm_error,
lse_mode=self.lse_mode,
parallel_dual_updates=self.parallel_dual_updates,
recenter=self.recenter_potentials
)[0],
lambda *_: jnp.array(jnp.inf, dtype=ot_prob.dtype),
state,
ot_prob,
)
errors = state.errors.at[iteration // self.inner_iterations, :].set(err)
state = state.set(errors=errors)
if self.progress_fn is not None:
jax.debug.callback(
self.progress_fn,
(iteration, self.inner_iterations, self.max_iterations, state)
)
return state
def _converged(self, state: SinkhornState, iteration: int) -> bool:
err = state.errors[iteration // self.inner_iterations - 1, 0]
return jnp.logical_and(iteration > 0, err < self.threshold)
def _diverged(self, state: SinkhornState, iteration: int) -> bool:
err = state.errors[iteration // self.inner_iterations - 1, 0]
return jnp.logical_not(jnp.isfinite(err))
def _continue(self, state: SinkhornState, iteration: int) -> bool:
"""Continue while not(converged) and not(diverged)."""
return jnp.logical_and(
jnp.logical_not(self._diverged(state, iteration)),
jnp.logical_not(self._converged(state, iteration))
)
@property
def outer_iterations(self) -> int:
"""Upper bound on number of times inner_iterations are carried out.
This integer can be used to set constant array sizes to track the algorithm
progress, notably errors.
"""
return np.ceil(self.max_iterations / self.inner_iterations).astype(int)
def init_state(
self, ot_prob: linear_problem.LinearProblem, init: Tuple[jnp.ndarray,
jnp.ndarray]
) -> SinkhornState:
"""Return the initial state of the loop."""
errors = -jnp.ones((self.outer_iterations, len(self.norm_error)),
dtype=ot_prob.dtype)
state = SinkhornState(init, errors=errors)
return self.anderson.init_maps(ot_prob, state) if self.anderson else state
def output_from_state(
self, ot_prob: linear_problem.LinearProblem, state: SinkhornState
) -> SinkhornOutput:
"""Create an output from a loop state.
Note:
When differentiating the regularized OT cost, and assuming Sinkhorn has
run to convergence, Danskin's (or the envelope)
`theorem <https://en.wikipedia.org/wiki/Danskin%27s_theorem>`_
:cite:`danskin:67,bertsekas:71`
states that the resulting OT cost as a function of the inputs
(``geometry``, ``a``, ``b``) behaves locally as if the dual optimal
potentials were frozen and did not vary with those inputs.
Notice this is only valid, as when using ``implicit_differentiation``
mode, if the Sinkhorn algorithm outputs potentials that are near optimal,
namely when the threshold value is set to a small tolerance.
The flag ``use_danskin`` controls whether that assumption is made. By
default, that flag is set to the value of ``implicit_differentiation`` if
not specified. If you wish to compute derivatives of order 2 and above,
set ``use_danskin`` to ``False``.
Args:
ot_prob: the transport problem.
state: a SinkhornState.
Returns:
A SinkhornOutput.
"""
geom = ot_prob.geom
f = state.fu if self.lse_mode else geom.potential_from_scaling(state.fu)
g = state.gv if self.lse_mode else geom.potential_from_scaling(state.gv)
if self.recenter_potentials:
f, g = state.recenter(f, g, ot_prob=ot_prob)
# By convention, the algorithm is said to have converged if the algorithm
# has not nan'ed during iterations (notice some errors might be infinite,
# this convention is used when the error is not recomputed), and if the
# last recorded error is lower than the threshold. Note that this will be
# the case if either the algorithm terminated earlier (in which case the
# last state.errors[-1] = -1 by convention) or if the algorithm carried out
# the maximal number of iterations and its last recorded error (at -1
# position) is lower than the threshold.
converged = jnp.logical_and(
jnp.logical_not(jnp.any(jnp.isnan(state.errors))), state.errors[-1]
< self.threshold
)[0]
return SinkhornOutput((f, g),
errors=state.errors[:, 0],
threshold=jnp.array(self.threshold),
converged=converged,
inner_iterations=self.inner_iterations)
@property
def norm_error(self) -> Tuple[int, ...]:
"""Powers used to compute the p-norm between marginal/target."""
# To change momentum adaptively, one needs errors in ||.||_1 norm.
# In that case, we add this exponent to the list of errors to compute,
# notably if that was not the error requested by the user.
if self.momentum and self.momentum.start > 0 and self._norm_error != 1:
return self._norm_error, 1
return self._norm_error,
# TODO(michalk8): in the future, enforce this (+ in GW) via abstract method
def create_initializer(self) -> init_lib.SinkhornInitializer: # noqa: D102
if isinstance(self.initializer, init_lib.SinkhornInitializer):
return self.initializer
if self.initializer == "default":
return init_lib.DefaultInitializer()
if self.initializer == "gaussian":
return init_lib.GaussianInitializer()
if self.initializer == "sorting":
return init_lib.SortingInitializer(**self.kwargs_init)
if self.initializer == "subsample":
return init_lib.SubsampleInitializer(**self.kwargs_init)
raise NotImplementedError(
f"Initializer `{self.initializer}` is not yet implemented."
)
def tree_flatten(self): # noqa: D102
aux = vars(self).copy()
aux["norm_error"] = aux.pop("_norm_error")
aux.pop("threshold")
return [self.threshold], aux
@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
return cls(**aux_data, threshold=children[0])
def run(
ot_prob: linear_problem.LinearProblem, solver: Sinkhorn,
init: Tuple[jnp.ndarray, ...]
) -> SinkhornOutput:
"""Run loop of the solver, outputting a state upgraded to an output."""
iter_fun = _iterations_implicit if solver.implicit_diff else iterations
out = iter_fun(ot_prob, solver, init)
# Be careful here, the geom and the cost are injected at the end, where it
# does not interfere with the implicit differentiation.
out = out.set_cost(ot_prob, solver.lse_mode, solver.use_danskin)
return out.set(ot_prob=ot_prob)
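# Illustrative only, not part of the solver: a hedged sketch, on hypothetical
# data, of the entry point above applied to an unbalanced problem, i.e. with
# ``tau_a, tau_b < 1`` set on the `LinearProblem`, and with potentials
# re-centered for faster convergence.
def _unbalanced_run_sketch() -> SinkhornOutput:
  from ott.geometry import pointcloud
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (21, 3))
  y = jax.random.normal(rng_y, (13, 3))
  geom = pointcloud.PointCloud(x, y, epsilon=0.1)
  prob = linear_problem.LinearProblem(geom, tau_a=0.9, tau_b=0.95)
  solver = Sinkhorn(lse_mode=True, recenter_potentials=True)
  return solver(prob)  # a `SinkhornOutput`, with `reg_ot_cost`, `errors`, ...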
def iterations(
ot_prob: linear_problem.LinearProblem, solver: Sinkhorn,
init: Tuple[jnp.ndarray, ...]
) -> SinkhornOutput:
"""Jittable Sinkhorn loop. args contain initialization variables."""
def cond_fn(
iteration: int, const: Tuple[linear_problem.LinearProblem, Sinkhorn],
state: SinkhornState
) -> bool:
_, solver = const
return solver._continue(state, iteration)
def body_fn(
iteration: int, const: Tuple[linear_problem.LinearProblem, Sinkhorn],
state: SinkhornState, compute_error: bool
) -> SinkhornState:
ot_prob, solver = const
return solver.one_iteration(ot_prob, state, iteration, compute_error)
# Run the Sinkhorn loop. Choose either a standard fixpoint_iter loop if
# differentiation is implicit, otherwise switch to the backprop friendly
# version of that loop if unrolling to differentiate.
if solver.implicit_diff:
fix_point = fixed_point_loop.fixpoint_iter
else:
fix_point = fixed_point_loop.fixpoint_iter_backprop
const = ot_prob, solver
state = solver.init_state(ot_prob, init)
state = fix_point(
cond_fn, body_fn, solver.min_iterations, solver.max_iterations,
solver.inner_iterations, const, state
)
return solver.output_from_state(ot_prob, state)
def _iterations_taped(
ot_prob: linear_problem.LinearProblem, solver: Sinkhorn,
init: Tuple[jnp.ndarray, ...]
) -> Tuple[SinkhornOutput, Tuple[jnp.ndarray, jnp.ndarray,
linear_problem.LinearProblem, Sinkhorn]]:
"""Run forward pass of the Sinkhorn algorithm storing side information."""
state = iterations(ot_prob, solver, init)
return state, (state.f, state.g, ot_prob, solver)
def _iterations_implicit_bwd(res, gr: SinkhornOutput):
"""Run Sinkhorn in backward mode, using implicit differentiation.
Args:
res: residual data sent from fwd pass, used for computations below. In this
case consists in the output itself, as well as inputs against which we
wish to differentiate.
gr: gradients w.r.t. outputs of fwd pass, here w.r.t. f, g, errors. Note
that differentiability w.r.t. errors is not handled, and only f, g are
considered.
Returns:
a tuple of gradients: PyTree for geom, one jnp.ndarray for each of a and b.
"""
f, g, ot_prob, solver = res
out = solver.implicit_diff.gradient(
ot_prob, f, g, solver.lse_mode, gr.potentials
)
return *out, None, None
# sets threshold, norm_errors, geom, a and b to be differentiable, as those are
# non-static. Only differentiability w.r.t. geom, a and b will be used.
_iterations_implicit = jax.custom_vjp(iterations)
_iterations_implicit.defvjp(_iterations_taped, _iterations_implicit_bwd)
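# Illustrative only, not part of the solver: a hedged sketch, on hypothetical
# data, of differentiating the regularized OT cost w.r.t. an input point cloud;
# with the default settings, gradients flow through the implicit
# differentiation rule wired up above, and the cost itself relies on Danskin's
# theorem.
def _reg_ot_cost_grad_sketch() -> jnp.ndarray:
  from ott.geometry import pointcloud
  rng_x, rng_y = jax.random.split(jax.random.PRNGKey(0))
  x = jax.random.normal(rng_x, (20, 2))
  y = jax.random.normal(rng_y, (30, 2))
  def reg_ot_cost(x: jnp.ndarray) -> jnp.ndarray:
    geom = pointcloud.PointCloud(x, y, epsilon=0.1)
    return Sinkhorn()(linear_problem.LinearProblem(geom)).reg_ot_cost
  return jax.grad(reg_ot_cost)(x)  # gradient w.r.t. `x`, shape (20, 2)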