Source code for ott.solvers.quadratic.gromov_wasserstein_lr

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A Jax implementation of the unbalanced low-rank GW algorithm."""
from typing import Any, Callable, Mapping, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp
import jax.scipy as jsp
import numpy as np

from ott import utils
from ott.geometry import geometry, low_rank
from ott.initializers.linear import initializers_lr
from ott.math import fixed_point_loop
from ott.math import unbalanced_functions as uf
from ott.math import utils as mu
from ott.problems.quadratic import quadratic_problem
from ott.solvers.linear import lr_utils, sinkhorn

__all__ = ["LRGromovWasserstein", "LRGWOutput"]

# Type of the optional progress callback. The solver invokes it with a single
# tuple ``(iteration, inner_iterations, max_iterations, state)`` and ignores
# its return value.
ProgressFunction = Callable[
    [Tuple[np.ndarray, np.ndarray, np.ndarray, "LRGWState"]], None]


class LRGWState(NamedTuple):
  """Loop state carried between iterations of the low-rank GW solver."""
  # first factor of the low-rank solution
  q: jnp.ndarray
  # second factor of the low-rank solution
  r: jnp.ndarray
  # weights of the low-rank solution
  g: jnp.ndarray
  # step size used by the most recent update
  gamma: float
  # cost recorded for every outer iteration
  costs: jnp.ndarray
  # error recorded for every outer iteration
  errors: jnp.ndarray
  # whether the error has already dropped below the convergence threshold
  crossed_threshold: bool

  def compute_error(  # noqa: D102
      self, previous_state: "LRGWState"
  ) -> float:
    """Measure how much the factors moved w.r.t. ``previous_state``.

    Args:
      previous_state: state before the current update.

    Returns:
      Sum of the ``mu.gen_js`` terms between the current and the previous
      ``q``, ``r`` and ``g``, scaled by ``1 / gamma ** 2``.
    """
    scale = (1.0 / self.gamma) ** 2
    div_q = mu.gen_js(self.q, previous_state.q, c=1.0)
    div_r = mu.gen_js(self.r, previous_state.r, c=1.0)
    div_g = mu.gen_js(self.g, previous_state.g, c=1.0)

    return scale * (div_q + div_r + div_g)

  def reg_gw_cost(  # noqa: D102
      self,
      ot_prob: quadratic_problem.QuadraticProblem,
      *,
      epsilon: float,
      use_danskin: bool = False
  ) -> float:
    """Regularized GW cost of the current factors.

    Args:
      ot_prob: quadratic problem the factors were computed for.
      epsilon: entropic regularization.
      use_danskin: forwarded to :func:`compute_reg_gw_cost`.

    Returns:
      The value returned by :func:`compute_reg_gw_cost`.
    """
    factors = (self.q, self.r, self.g)
    return compute_reg_gw_cost(
        *factors, ot_prob, epsilon=epsilon, use_danskin=use_danskin
    )

  def set(self, **kwargs: Any) -> "LRGWState":
    """Return a new state in which the given fields are overwritten."""
    return self._replace(**kwargs)


def compute_reg_gw_cost(
    q: jnp.ndarray,
    r: jnp.ndarray,
    g: jnp.ndarray,
    ot_prob: quadratic_problem.QuadraticProblem,
    epsilon: float,
    use_danskin: bool = False
) -> float:
  """Evaluate the regularized primal GW cost of a low-rank solution.

  Args:
    q: first factor of solution
    r: second factor of solution
    g: weights of solution
    ot_prob: quadratic problem
    epsilon: Entropic regularization.
    use_danskin: if True, gradients are not propagated through ``q``, ``r``
      and ``g`` (Danskin's theorem :cite:`danskin:67,bertsekas:71`).

  Returns:
    The primal cost of the low-rank solution, minus ``epsilon`` times the
    generalized entropies of the factors, plus the marginal KL penalties
    whenever ``tau_a`` or ``tau_b`` differ from 1.
  """
  if use_danskin:
    # treat the factors as constants when differentiating the cost
    q, r, g = (jax.lax.stop_gradient(x) for x in (q, r, g))

  def generalized_entropy(x: jnp.ndarray) -> float:
    return jnp.sum(jsp.special.entr(x) + x)

  # temporary output, only used to reach `primal_cost`
  solution = LRGWOutput(
      q=q,
      r=r,
      g=g,
      costs=None,
      errors=None,
      ot_prob=ot_prob,
      epsilon=None,
      inner_iterations=None,
      converged=False,
  )

  entropy = (
      generalized_entropy(q) + generalized_entropy(r) + generalized_entropy(g)
  )
  cost = solution.primal_cost - epsilon * entropy

  # marginal penalties, only active when the corresponding side is relaxed
  if ot_prob.tau_a != 1.0:
    penalty_a = mu.gen_kl(jnp.sum(q, axis=1), ot_prob.a)
    cost = cost + uf.rho(1.0, ot_prob.tau_a) * penalty_a
  if ot_prob.tau_b != 1.0:
    penalty_b = mu.gen_kl(jnp.sum(r, axis=1), ot_prob.b)
    cost = cost + uf.rho(1.0, ot_prob.tau_b) * penalty_b

  return cost


class LRGWOutput(NamedTuple):
  """Transport interface for a low-rank GW solution."""
  # first factor of the solution
  q: jnp.ndarray
  # second factor of the solution
  r: jnp.ndarray
  # weights of the solution
  g: jnp.ndarray
  costs: jnp.ndarray
  # TODO(michalk8): must be called `errors`, because of `store_inner_errors`
  # in future, enforce via class hierarchy
  errors: jnp.ndarray
  ot_prob: quadratic_problem.QuadraticProblem
  epsilon: float
  inner_iterations: int
  converged: bool
  reg_gw_cost: Optional[float] = None

  def set(self, **kwargs: Any) -> "LRGWOutput":
    """Return a new output in which the given fields are overwritten."""
    return self._replace(**kwargs)

  def set_cost(
      self,
      ot_prob: quadratic_problem.QuadraticProblem,
      lse_mode: bool,
      use_danskin: bool = False
  ) -> "LRGWOutput":
    """Return a copy of self with ``reg_gw_cost`` filled in.

    Args:
      ot_prob: quadratic problem used to evaluate the cost.
      lse_mode: ignored.
      use_danskin: forwarded to :meth:`compute_reg_gw_cost`.

    Returns:
      The updated output.
    """
    del lse_mode
    cost = self.compute_reg_gw_cost(ot_prob, use_danskin)
    return self.set(reg_gw_cost=cost)

  def compute_reg_gw_cost(
      self,
      ot_prob: quadratic_problem.QuadraticProblem,
      use_danskin: bool = False,
  ) -> float:
    """Regularized GW cost of this solution, using ``self.epsilon``."""
    factors = (self.q, self.r, self.g)
    return compute_reg_gw_cost(
        *factors, ot_prob, epsilon=self.epsilon, use_danskin=use_danskin
    )

  @property
  def geom(self) -> geometry.Geometry:
    """Linearized geometry."""
    return _linearized_geometry(self.ot_prob, q=self.q, r=self.r, g=self.g)

  @property
  def a(self) -> jnp.ndarray:
    """The ``a`` attribute of the underlying problem."""
    return self.ot_prob.a

  @property
  def b(self) -> jnp.ndarray:
    """The ``b`` attribute of the underlying problem."""
    return self.ot_prob.b

  @property
  def n_iters(self) -> int:
    """Number of ``errors`` entries different from -1, times inner iterations."""
    n_recorded = jnp.sum(self.errors != -1)
    return n_recorded * self.inner_iterations

  @property
  def matrix(self) -> jnp.ndarray:
    """Transport matrix if it can be instantiated."""
    scaled_q = self.q * self._inv_g
    return scaled_q @ self.r.T

  def apply(self, inputs: jnp.ndarray, axis: int = 0) -> jnp.ndarray:
    """Apply the transport to an array without instantiating the matrix.

    Args:
      inputs: array to which the transport is applied.
      axis: if 1, compute ``inputs @ matrix.T``; for any other value, compute
        ``inputs @ matrix``.

    Returns:
      The transported array.
    """
    if axis == 1:
      first, second = self.r, self.q
    else:
      first, second = self.q, self.r
    # with n = q.shape[0], m = r.shape[0]; for `axis=0`:
    # (batch, n) @ (n, rank), scaled by (rank,), then @ (rank, m)
    return ((inputs @ first) * self._inv_g) @ second.T

  def marginal(self, axis: int) -> jnp.ndarray:
    """Apply the transport to a vector of ones along ``axis``."""
    n_ones = self.q.shape[0] if axis == 0 else self.r.shape[0]
    return self.apply(jnp.ones(n_ones), axis=axis)

  def cost_at_geom(self, other_geom: geometry.Geometry) -> float:
    """Return OT cost for current solution, evaluated at any cost matrix."""
    cost_times_r = other_geom.apply_cost(self.r, axis=1)
    return jnp.sum(self.q * cost_times_r * self._inv_g)

  def transport_cost_at_geom(self, other_geom: geometry.Geometry) -> float:
    """Return (by recomputing it) bare transport cost of current solution."""
    return self.cost_at_geom(other_geom)

  @property
  def primal_cost(self) -> float:
    """Return (by recomputing it) transport cost of current solution."""
    prob = self.ot_prob
    # use the prescribed marginals unless the corresponding side is relaxed
    marginal_a = prob.a if prob.tau_a == 1.0 else self.q.sum(1)
    marginal_b = prob.b if prob.tau_b == 1.0 else self.r.sum(1)

    quad_cost = 0.5 * self.transport_cost_at_geom(other_geom=self.geom)
    quad_cost = quad_cost + jnp.vdot(
        prob.geom_xx.apply_square_cost(marginal_a), marginal_a
    )
    quad_cost = quad_cost + jnp.vdot(
        prob.geom_yy.apply_square_cost(marginal_b), marginal_b
    )

    if prob.is_fused:
      # mix in the linear term of the fused problem
      alpha = prob.fused_penalty / (prob.fused_penalty + 1.0)
      norm_g = jnp.linalg.norm(self.g, ord=1)
      lin_cost = self.cost_at_geom(prob.geom_xy)
      return alpha * norm_g * lin_cost + (1.0 - alpha) * quad_cost

    return quad_cost

  @property
  def transport_mass(self) -> float:
    """Sum of transport matrix."""
    return self.marginal(0).sum()

  @property
  def _inv_g(self) -> jnp.ndarray:
    # element-wise reciprocal of the weights
    return 1.0 / self.g
@jax.tree_util.register_pytree_node_class
class LRGromovWasserstein(sinkhorn.Sinkhorn):
  r"""Low-rank Gromov-Wasserstein solver :cite:`scetbon:23`.

  The algorithm minimizes a non-convex problem. It therefore requires special
  care to initialization and convergence. Convergence is monitored through the
  error between successive iterates of the factors, see
  :meth:`LRGWState.compute_error`.

  .. warning::
    This solver is intended only for the **unbalanced** case. The balanced
    case is implemented in
    :class:`~ott.solvers.quadratic.gromov_wasserstein.GromovWasserstein`
    and will be unified here in a future release.

  Args:
    rank: Rank constraint on the coupling to minimize the quadratic OT problem.
    gamma: The (inverse of) gradient step size used by mirror descent.
    gamma_rescale: Whether to rescale :math:`\gamma` every iteration as
      described in :cite:`scetbon:22b`.
    epsilon: Entropic regularization added on top of low-rank problem.
    initializer: How to initialize the :math:`Q`, :math:`R` and :math:`g`
      factors. If :obj:`None`, use
      :class:`~ott.initializers.linear.initializers_lr.RandomInitializer`.
    lse_mode: Whether to run computations in LSE or kernel mode.
    inner_iterations: Number of inner iterations used by the algorithm before
      re-evaluating progress.
    min_iterations: The minimum number of low-rank GW iterations
      carried out before the error is computed and monitored.
    max_iterations: The maximum number of low-rank GW iterations.
    use_danskin: Use Danskin theorem to evaluate gradient of objective w.r.t.
      input parameters. Only `True` handled at this moment.
    implicit_diff: Whether to use implicit differentiation. Currently, only
      ``implicit_diff = False`` is implemented.
    progress_fn: callback function which gets called during the GW
      iterations, so the user can display the error at each iteration,
      e.g., using a progress bar. See :func:`~ott.utils.default_progress_fn`
      for a basic implementation.
    kwargs_dys: Keyword arguments passed to :meth:`dykstra_update_lse`,
      :meth:`dykstra_update_kernel` or one of the functions defined in
      :mod:`ott.solvers.linear`, depending on the ``lse_mode``.
    kwargs: Keyword arguments for
      :class:`~ott.solvers.linear.sinkhorn.Sinkhorn`.
  """

  def __init__(
      self,
      rank: int,
      gamma: float = 10.0,
      gamma_rescale: bool = True,
      epsilon: float = 0.0,
      initializer: Optional[initializers_lr.LRInitializer] = None,
      lse_mode: bool = True,
      use_danskin: bool = True,
      implicit_diff: bool = False,
      inner_iterations: int = 2_000,
      min_iterations: int = 10_000,
      max_iterations: int = 100_000,
      kwargs_dys: Optional[Mapping[str, Any]] = None,
      progress_fn: Optional[ProgressFunction] = None,
      **kwargs: Any,
  ):
    assert not implicit_diff, "Implicit diff. not yet implemented."
    super().__init__(
        lse_mode=lse_mode,
        inner_iterations=inner_iterations,
        min_iterations=min_iterations,
        max_iterations=max_iterations,
        use_danskin=use_danskin,
        implicit_diff=implicit_diff,
        **kwargs
    )
    self.rank = rank
    self.gamma = gamma
    self.gamma_rescale = gamma_rescale
    self.epsilon = epsilon
    # fall back to a random initializer of the requested rank
    self.initializer = initializers_lr.RandomInitializer(
        rank
    ) if initializer is None else initializer
    self.progress_fn = progress_fn
    # can be `None`
    self.kwargs_dys = {} if kwargs_dys is None else kwargs_dys

  def __call__(
      self,
      ot_prob: quadratic_problem.QuadraticProblem,
      init: Optional[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]] = None,
      rng: Optional[jax.Array] = None,
      **kwargs: Any,
  ) -> LRGWOutput:
    """Run the low-rank Gromov-Wasserstein solver.

    Args:
      ot_prob: Quadratic OT problem.
      init: Initial values of the low-rank factors:

        - :attr:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGWOutput.q`.
        - :attr:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGWOutput.r`.
        - :attr:`~ott.solvers.quadratic.gromov_wasserstein_lr.LRGWOutput.g`.

        If :obj:`None`, use the initializer.
      rng: Random key for seeding.
      kwargs: Keyword arguments for the initializer.

    Returns:
      The low-rank GW output.
    """
    rng = utils.default_prng_key(rng)
    # one key for the low-rank conversion, one for the initializer
    rng_lrc, rng_init = jax.random.split(rng)

    if ot_prob._is_low_rank_convertible:
      ot_prob = ot_prob.to_low_rank(rng=rng_lrc)

    if init is None:
      init = self.initializer(ot_prob, rng=rng_init, **kwargs)

    return run(ot_prob, self, init)

  def _get_costs(
      self,
      ot_prob: quadratic_problem.QuadraticProblem,
      state: LRGWState,
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, float]:
    """Build the cost terms and step size used by one Dykstra update.

    Args:
      ot_prob: quadratic problem.
      state: current state, providing the factors ``q``, ``r`` and ``g``.

    Returns:
      The tuple ``(c_q, c_r, c_g, gamma)``, where ``gamma`` is the step size
      after the optional rescaling and the ``epsilon`` correction.
    """
    q, r, g = state.q, state.r, state.g
    log_q, log_r, log_g = mu.safe_log(q), mu.safe_log(r), mu.safe_log(g)
    inv_g = 1.0 / g[None, :]
    lin_geom = _linearized_geometry(ot_prob, q=q, r=r, g=g)

    tmp = lin_geom.apply_cost(r, axis=1)
    grad_q = tmp * inv_g
    if ot_prob.tau_a != 1.0:
      # unbalanced grad
      grad_q += 2.0 * ot_prob.geom_xx.apply_square_cost(
          q.sum(1), axis=1
      )[:, None]

    grad_r = lin_geom.apply_cost(q, axis=0) * inv_g
    if ot_prob.tau_b != 1.0:
      # unbalanced grad
      grad_r += 2.0 * ot_prob.geom_yy.apply_square_cost(
          r.sum(1), axis=1
      )[:, None]

    omega_quad = jnp.sum(q * tmp, axis=0)
    grad_g = -omega_quad / (g ** 2)

    if ot_prob.is_fused:
      # blend in the gradients of the linear term with weight `alpha`
      alpha = ot_prob.fused_penalty / (ot_prob.fused_penalty + 1.0)
      norm_g = jnp.linalg.norm(g, ord=1)

      tmp = ot_prob.geom_xy.apply_cost(r, axis=1)
      lin_grad_q = tmp * inv_g * norm_g
      lin_grad_r = ot_prob.geom_xy.apply_cost(q) * inv_g * norm_g
      omega_lin = jnp.sum(q * tmp, axis=0)
      lin_grad_g = -omega_lin / (g ** 2) * norm_g + jnp.sum(q * tmp * inv_g)

      grad_q = alpha * lin_grad_q + (1.0 - alpha) * grad_q
      grad_r = alpha * lin_grad_r + (1.0 - alpha) * grad_r
      grad_g = alpha * lin_grad_g + (1.0 - alpha) * grad_g

    # entropic regularization terms
    grad_q += self.epsilon * log_q
    grad_r += self.epsilon * log_r
    grad_g += self.epsilon * log_g

    if self.gamma_rescale:
      # divide the step size by the largest squared sup-norm of the gradients
      norm_q = jnp.max(jnp.abs(grad_q)) ** 2
      norm_r = jnp.max(jnp.abs(grad_r)) ** 2
      norm_g = jnp.max(jnp.abs(grad_g)) ** 2
      gamma = self.gamma / jnp.max(jnp.array([norm_q, norm_r, norm_g]))
    else:
      gamma = self.gamma

    eps_factor = 1.0 / (self.epsilon * gamma + 1.0)
    gamma *= eps_factor

    c_q = -gamma * grad_q + eps_factor * log_q
    c_r = -gamma * grad_r + eps_factor * log_r
    c_g = -gamma * grad_g + eps_factor * log_g

    return c_q, c_r, c_g, gamma

  # TODO(michalk8): move to `lr_utils` when refactoring this in the future
  def dykstra_update_lse(
      self,
      c_q: jnp.ndarray,
      c_r: jnp.ndarray,
      h: jnp.ndarray,
      gamma: float,
      ot_prob: quadratic_problem.QuadraticProblem,
      min_entry_value: float = 1e-6,
      tolerance: float = 1e-3,
      min_iter: int = 0,
      inner_iter: int = 10,
      max_iter: int = 10000
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Run Dykstra's algorithm in log-space.

    Args:
      c_q: cost term for the first factor.
      c_r: cost term for the second factor.
      h: initial value of the log-space variable for the weights; the weights
        are recovered as ``exp(gamma * h)``.
      gamma: step size.
      ot_prob: quadratic problem, provides the marginals ``a`` and ``b``.
      min_entry_value: lower bound enforced on the weights.
      tolerance: the loop continues while the error is above this value.
      min_iter: forwarded to the fixed-point loop; the error is only
        evaluated once ``iteration >= min_iter``.
      inner_iter: forwarded to the fixed-point loop.
      max_iter: forwarded to the fixed-point loop.

    Returns:
      The updated ``(q, r, g)``.
    """
    # shortcuts for problem's definition.
    r = self.rank
    n, m = ot_prob.geom_xx.shape[0], ot_prob.geom_yy.shape[0]
    loga, logb = jnp.log(ot_prob.a), jnp.log(ot_prob.b)

    h_old = h
    g1_old, g2_old = jnp.zeros(r), jnp.zeros(r)
    f1, f2 = jnp.zeros(n), jnp.zeros(m)

    w_gi, w_gp = jnp.zeros(r), jnp.zeros(r)
    w_q, w_r = jnp.zeros(r), jnp.zeros(r)
    # start at infinity so that `cond_fn` holds until an error is computed
    err = jnp.inf
    state_inner = f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err
    constants = c_q, c_r, loga, logb

    def cond_fn(
        iteration: int, constants: Tuple[jnp.ndarray, ...],
        state_inner: Tuple[jnp.ndarray, ...]
    ) -> bool:
      del iteration, constants
      *_, err = state_inner
      return err > tolerance

    def _softm(
        f: jnp.ndarray, g: jnp.ndarray, c: jnp.ndarray, axis: int
    ) -> jnp.ndarray:
      return jsp.special.logsumexp(
          gamma * (f[:, None] + g[None, :] - c), axis=axis
      )

    def body_fn(
        iteration: int, constants: Tuple[jnp.ndarray, ...],
        state_inner: Tuple[jnp.ndarray, ...], compute_error: bool
    ) -> Tuple[jnp.ndarray, ...]:
      # TODO(michalk8): in the future, use `NamedTuple`
      f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err = state_inner
      c_q, c_r, loga, logb = constants

      # First Projection
      # where the log-marginal is not finite, it is used as-is
      f1 = jnp.where(
          jnp.isfinite(loga),
          (loga - _softm(f1, g1_old, c_q, axis=1)) / gamma + f1, loga
      )
      f2 = jnp.where(
          jnp.isfinite(logb),
          (logb - _softm(f2, g2_old, c_r, axis=1)) / gamma + f2, logb
      )

      h = h_old + w_gi
      # keeps `exp(gamma * h)` above `min_entry_value`
      h = jnp.maximum(jnp.log(min_entry_value) / gamma, h)
      w_gi += h_old - h
      h_old = h

      # Update couplings
      g_q = _softm(f1, g1_old, c_q, axis=0)
      g_r = _softm(f2, g2_old, c_r, axis=0)

      # Second Projection
      h = (1.0 / 3.0) * (h_old + w_gp + w_q + w_r)
      h += g_q / (3.0 * gamma)
      h += g_r / (3.0 * gamma)
      g1 = h + g1_old - g_q / gamma
      g2 = h + g2_old - g_r / gamma

      w_q = w_q + g1_old - g1
      w_r = w_r + g2_old - g2
      w_gp = h_old + w_gp - h

      q, r, _ = recompute_couplings(f1, g1, c_q, f2, g2, c_r, h, gamma)

      g1_old = g1
      g2_old = g2
      h_old = h

      # keep the previous error unless this iteration is asked to compute it
      err = jax.lax.cond(
          jnp.logical_and(compute_error, iteration >= min_iter),
          lambda: dykstra_solution_error(q, r, ot_prob, self.norm_error)[0],
          lambda: err
      )

      return f1, f2, g1_old, g2_old, h_old, w_gi, w_gp, w_q, w_r, err

    def recompute_couplings(
        f1: jnp.ndarray,
        g1: jnp.ndarray,
        c_q: jnp.ndarray,
        f2: jnp.ndarray,
        g2: jnp.ndarray,
        c_r: jnp.ndarray,
        h: jnp.ndarray,
        gamma: float,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
      q = jnp.exp(gamma * (f1[:, None] + g1[None, :] - c_q))
      r = jnp.exp(gamma * (f2[:, None] + g2[None, :] - c_r))
      g = jnp.exp(gamma * h)
      return q, r, g

    state_inner = fixed_point_loop.fixpoint_iter_backprop(
        cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner
    )

    f1, f2, g1_old, g2_old, h_old, _, _, _, _, _ = state_inner
    return recompute_couplings(f1, g1_old, c_q, f2, g2_old, c_r, h_old, gamma)

  def dykstra_update_kernel(
      self,
      k_q: jnp.ndarray,
      k_r: jnp.ndarray,
      k_g: jnp.ndarray,
      gamma: float,
      ot_prob: quadratic_problem.QuadraticProblem,
      min_entry_value: float = 1e-6,
      tolerance: float = 1e-3,
      min_iter: int = 0,
      inner_iter: int = 10,
      max_iter: int = 10000
  ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Run Dykstra's algorithm in kernel (scaling) space.

    Args:
      k_q: kernel for the first factor.
      k_r: kernel for the second factor.
      k_g: initial value of the weights.
      gamma: unused, kept for signature parity with
        :meth:`dykstra_update_lse`.
      ot_prob: quadratic problem, provides the marginals ``a`` and ``b``.
      min_entry_value: lower bound enforced on the weights.
      tolerance: the loop continues while the error is above this value.
      min_iter: forwarded to the fixed-point loop; the error is only
        evaluated once ``iteration >= min_iter``.
      inner_iter: forwarded to the fixed-point loop.
      max_iter: forwarded to the fixed-point loop.

    Returns:
      The updated ``(q, r, g)``.
    """
    # shortcuts for problem's definition.
    del gamma
    rank = self.rank
    n, m = ot_prob.geom_xx.shape[0], ot_prob.geom_yy.shape[0]
    a, b = ot_prob.a, ot_prob.b
    # masks of the strictly positive marginal entries
    supp_a, supp_b = a > 0, b > 0

    g_old = k_g
    v1_old, v2_old = jnp.ones(rank), jnp.ones(rank)
    u1, u2 = jnp.ones(n), jnp.ones(m)

    q_gi, q_gp = jnp.ones(rank), jnp.ones(rank)
    q_q, q_r = jnp.ones(rank), jnp.ones(rank)
    # start at infinity so that `cond_fn` holds until an error is computed
    err = jnp.inf
    state_inner = u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err
    constants = k_q, k_r, k_g, a, b

    def cond_fn(
        iteration: int, constants: Tuple[jnp.ndarray, ...],
        state_inner: Tuple[jnp.ndarray, ...]
    ) -> bool:
      del iteration, constants
      *_, err = state_inner
      return err > tolerance

    def body_fn(
        iteration: int, constants: Tuple[jnp.ndarray, ...],
        state_inner: Tuple[jnp.ndarray, ...], compute_error: bool
    ) -> Tuple[jnp.ndarray, ...]:
      # TODO(michalk8): in the future, use `NamedTuple`
      u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err = state_inner
      k_q, k_r, k_g, a, b = constants

      # First Projection
      # scalings are set to 0 outside of the support of the marginals
      u1 = jnp.where(supp_a, a / jnp.dot(k_q, v1_old), 0.0)
      u2 = jnp.where(supp_b, b / jnp.dot(k_r, v2_old), 0.0)
      g = jnp.maximum(min_entry_value, g_old * q_gi)
      q_gi = (g_old * q_gi) / g
      g_old = g

      # Second Projection
      v1_trans = jnp.dot(k_q.T, u1)
      v2_trans = jnp.dot(k_r.T, u2)
      g = (g_old * q_gp * v1_old * q_q * v1_trans * v2_old * q_r *
           v2_trans) ** (1 / 3)
      v1 = g / v1_trans
      v2 = g / v2_trans
      q_gp = (g_old * q_gp) / g
      q_q = (q_q * v1_old) / v1
      q_r = (q_r * v2_old) / v2
      v1_old = v1
      v2_old = v2
      g_old = g

      # Compute Couplings
      q, r, _ = recompute_couplings(u1, v1, k_q, u2, v2, k_r, g)

      # keep the previous error unless this iteration is asked to compute it
      err = jax.lax.cond(
          jnp.logical_and(compute_error, iteration >= min_iter),
          lambda: dykstra_solution_error(q, r, ot_prob, self.norm_error)[0],
          lambda: err
      )

      return u1, u2, v1_old, v2_old, g_old, q_gi, q_gp, q_q, q_r, err

    def recompute_couplings(
        u1: jnp.ndarray,
        v1: jnp.ndarray,
        k_q: jnp.ndarray,
        u2: jnp.ndarray,
        v2: jnp.ndarray,
        k_r: jnp.ndarray,
        g: jnp.ndarray,
    ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
      q = u1.reshape((-1, 1)) * k_q * v1.reshape((1, -1))
      r = u2.reshape((-1, 1)) * k_r * v2.reshape((1, -1))
      return q, r, g

    state_inner = fixed_point_loop.fixpoint_iter_backprop(
        cond_fn, body_fn, min_iter, max_iter, inner_iter, constants, state_inner
    )

    u1, u2, v1_old, v2_old, g_old, _, _, _, _, _ = state_inner
    return recompute_couplings(u1, v1_old, k_q, u2, v2_old, k_r, g_old)

  def lse_step(
      self, ot_prob: quadratic_problem.QuadraticProblem, state: LRGWState,
      iteration: int
  ) -> LRGWState:
    """Low-rank GW LSE update.

    Args:
      ot_prob: quadratic problem.
      state: current state.
      iteration: current iteration; not used by this method.

    Returns:
      The state with updated ``q``, ``r``, ``g`` and ``gamma``.
    """
    c_q, c_r, c_g, gamma = self._get_costs(ot_prob, state)

    if ot_prob.is_balanced:
      # the balanced Dykstra update takes the terms rescaled by the step size
      c_q, c_r, h = c_q / -gamma, c_r / -gamma, c_g / gamma
      q, r, g = self.dykstra_update_lse(
          c_q, c_r, h, gamma, ot_prob, **self.kwargs_dys
      )
    else:
      q, r, g = lr_utils.unbalanced_dykstra_lse(
          c_q, c_r, c_g, gamma, ot_prob, **self.kwargs_dys
      )
    return state.set(q=q, g=g, r=r, gamma=gamma)

  def kernel_step(
      self, ot_prob: quadratic_problem.QuadraticProblem, state: LRGWState,
      iteration: int
  ) -> LRGWState:
    """Low-rank GW kernel update.

    Args:
      ot_prob: quadratic problem.
      state: current state.
      iteration: current iteration; not used by this method.

    Returns:
      The state with updated ``q``, ``r``, ``g`` and ``gamma``.
    """
    c_q, c_r, c_g, gamma = self._get_costs(ot_prob, state)
    # move from log-space terms to kernels
    c_q, c_r, c_g = jnp.exp(c_q), jnp.exp(c_r), jnp.exp(c_g)

    if ot_prob.is_balanced:
      q, r, g = self.dykstra_update_kernel(
          c_q, c_r, c_g, gamma, ot_prob, **self.kwargs_dys
      )
    else:
      q, r, g = lr_utils.unbalanced_dykstra_kernel(
          c_q, c_r, c_g, gamma, ot_prob, **self.kwargs_dys
      )
    return state.set(q=q, g=g, r=r, gamma=gamma)

  def one_iteration(
      self, ot_prob: quadratic_problem.QuadraticProblem, state: LRGWState,
      iteration: int, compute_error: bool
  ) -> LRGWState:
    """Carries out one low-rank GW iteration.

    Depending on lse_mode, these iterations can be either in:

      - log-space for numerical stability.
      - scaling space, using standard kernel-vector multiply operations.

    Args:
      ot_prob: the transport problem definition
      state: the current state.
      iteration: the current iteration of the GW outer loop.
      compute_error: flag to indicate this iteration computes/stores an error

    Returns:
      The updated state.
    """
    previous_state = state
    # slot of `costs`/`errors` written by this iteration
    it = iteration // self.inner_iterations
    if self.lse_mode:  # In lse_mode, run additive updates.
      state = self.lse_step(ot_prob, state, iteration)
    else:
      state = self.kernel_step(ot_prob, state, iteration)

    # re-computes the cost if compute_error is True and `min_iterations` has
    # been reached, else sets it to inf.
    cost = jax.lax.cond(
        jnp.logical_and(compute_error, iteration >= self.min_iterations),
        lambda: state.reg_gw_cost(ot_prob, epsilon=self.epsilon),
        lambda: jnp.inf
    )
    # the error does not depend on `compute_error`, only on `min_iterations`
    error = jax.lax.cond(
        iteration >= self.min_iterations,
        lambda: state.compute_error(previous_state), lambda: jnp.inf
    )
    # latched: stays True once the stored error at `it - 1` was at or above
    # the threshold while the new error is below it; for `it == 0` the index
    # `it - 1` refers to the last slot
    crossed_threshold = jnp.logical_or(
        state.crossed_threshold,
        jnp.logical_and(
            state.errors[it - 1] >= self.threshold, error < self.threshold
        )
    )

    state = state.set(
        costs=state.costs.at[it].set(cost),
        errors=state.errors.at[it].set(error),
        crossed_threshold=crossed_threshold,
    )

    if self.progress_fn is not None:
      jax.debug.callback(
          self.progress_fn,
          (iteration, self.inner_iterations, self.max_iterations, state)
      )

    return state

  @property
  def norm_error(self) -> Tuple[int]:  # noqa: D102
    # wrapped in a 1-tuple, as expected by `dykstra_solution_error`
    return self._norm_error,

  def init_state(
      self, ot_prob: quadratic_problem.QuadraticProblem,
      init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]
  ) -> LRGWState:
    """Return the initial state of the loop.

    Args:
      ot_prob: quadratic problem; not used by this method.
      init: initial values of ``(q, r, g)``.

    Returns:
      The initial state, with ``costs`` and ``errors`` filled with -1.
    """
    q, r, g = init
    return LRGWState(
        q=q,
        r=r,
        g=g,
        gamma=self.gamma,
        costs=-jnp.ones(self.outer_iterations),
        errors=-jnp.ones(self.outer_iterations),
        crossed_threshold=False,
    )

  def output_from_state(
      self, ot_prob: quadratic_problem.QuadraticProblem, state: LRGWState
  ) -> LRGWOutput:
    """Create an output from a loop state.

    Args:
      ot_prob: the transport problem.
      state: GW state.

    Returns:
      A LRGWOutput.
    """
    # entries of `errors` still equal to -1 have never been written
    it = jnp.sum(state.errors != -1.0) * self.inner_iterations
    converged = self._converged(state, it)

    return LRGWOutput(
        q=state.q,
        r=state.r,
        g=state.g,
        ot_prob=ot_prob,
        costs=state.costs,
        errors=state.errors,
        epsilon=self.epsilon,
        inner_iterations=self.inner_iterations,
        converged=converged,
    )

  def _converged(self, state: LRGWState, iteration: int) -> bool:
    """Decide convergence from the two most recently stored errors."""

    def conv_crossed(prev_err: float, curr_err: float) -> bool:
      return jnp.logical_and(
          prev_err < self.threshold, curr_err < self.threshold
      )

    def conv_not_crossed(prev_err: float, curr_err: float) -> bool:
      return jnp.logical_and(curr_err < prev_err, curr_err < self.threshold)

    # for convergence error, we consider 2 possibilities:
    # 1. we either crossed the convergence threshold; in this case we require
    #    that the previous error was also below the threshold
    # 2. we haven't crossed the threshold; in this case, we can be below or
    #    above the threshold:
    #    if we're above, we wait until we reach the convergence threshold and
    #    then, the above condition applies
    #    if we're below and we improved w.r.t. the previous iteration,
    #    we have converged; otherwise we continue, since we may be stuck
    #    in a local minimum (e.g., during the initial iterations)

    it = iteration // self.inner_iterations
    return jax.lax.cond(
        state.crossed_threshold, conv_crossed, conv_not_crossed,
        state.errors[it - 2], state.errors[it - 1]
    )

  def _diverged(self, state: LRGWState, iteration: int) -> bool:
    """Flag divergence when the last stored error and cost are both non-finite."""
    it = iteration // self.inner_iterations - 1
    is_not_finite = jnp.logical_and(
        jnp.logical_not(jnp.isfinite(state.errors[it])),
        jnp.logical_not(jnp.isfinite(state.costs[it]))
    )
    # `jnp.inf` is used if `it < self.min_iterations`
    # NOTE(review): `it` is an index into `errors`/`costs` (one slot per
    # `inner_iterations` iterations), while `one_iteration` compares the raw
    # `iteration` against `min_iterations`; the guard below mixes these units
    # and, with the default settings, cannot become True before
    # `max_iterations` is reached - confirm whether this is intended.
    return jnp.logical_and(it >= self.min_iterations, is_not_finite)
def run(
    ot_prob: quadratic_problem.QuadraticProblem,
    solver: LRGromovWasserstein,
    init: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
) -> LRGWOutput:
  """Run the solver loop and return its output with the cost attached.

  Args:
    ot_prob: quadratic problem.
    solver: low-rank GW solver.
    init: initial values of the factors.

  Returns:
    The output of the loop, with ``reg_gw_cost`` and ``ot_prob`` set.
  """
  result = sinkhorn.iterations(ot_prob, solver, init)
  result = result.set_cost(
      ot_prob, lse_mode=solver.lse_mode, use_danskin=solver.use_danskin
  )
  return result.set(ot_prob=ot_prob)


def dykstra_solution_error(
    q: jnp.ndarray, r: jnp.ndarray, ot_prob: quadratic_problem.QuadraticProblem,
    norm_error: Tuple[int, ...]
) -> jnp.ndarray:
  """Compute solution error as the deviation from the marginal constraints.

  Args:
    q: first factor of solution.
    r: second factor of solution.
    ot_prob: quadratic problem, provides the marginals ``a`` and ``b``.
    norm_error: p-norms used to compute the error.

  Returns:
    One value per entry of ``norm_error``: the sum of the p-norm deviations of
    the row sums of ``q`` from ``a``, of the row sums of ``r`` from ``b``, and
    of the column sums of ``q`` from the column sums of ``r``.
  """
  norms = jnp.array(norm_error)

  def p_norm_deviation(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    # one p-norm of `x - y` for every requested p
    return jnp.sum(jnp.abs(x - y) ** norms[:, None], axis=1) ** (1.0 / norms)

  dev_a = p_norm_deviation(jnp.sum(q, axis=1), ot_prob.a)
  dev_b = p_norm_deviation(jnp.sum(r, axis=1), ot_prob.b)
  dev_g = p_norm_deviation(jnp.sum(q, axis=0), jnp.sum(r, axis=0))

  return dev_a + dev_b + dev_g


def _linearized_geometry(
    prob: quadratic_problem.QuadraticProblem,
    *,
    q: jnp.ndarray,
    r: jnp.ndarray,
    g: jnp.ndarray,
) -> low_rank.LRCGeometry:
  # Low-rank geometry built from the intra-domain costs applied to the
  # factors, each scaled by `1 / sqrt(g)`.
  scale = 1.0 / jnp.sqrt(g[None, :])

  # TODO(michalk8): below is for squared loss, handle KL loss in the future;
  # will need to be updated in many other places as well
  factor_x = -4.0 * prob.geom_xx.apply_cost(q, axis=1) * scale
  factor_y = prob.geom_yy.apply_cost(r, axis=1) * scale

  return low_rank.LRCGeometry(factor_x, factor_y)