Source code for ott.solvers.linear.acceleration

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

import jax
import jax.numpy as jnp

from ott import utils

if TYPE_CHECKING:
  from ott.problems.linear import linear_problem
  from ott.solvers.linear import sinkhorn

__all__ = ["AndersonAcceleration", "Momentum"]


@utils.register_pytree_node
class AndersonAcceleration:
  """Implements Anderson acceleration for Sinkhorn."""

  # TODO(michalk8): use memory=0 as no Anderson acceleration?
  memory: int = 2  # Number of iterates considered to form interpolation.
  refresh_every: int = 1  # Recompute interpolation periodically.
  ridge_identity: float = 1e-2  # Ridge used in the linear system.

  def extrapolation(self, xs: jnp.ndarray, fxs: jnp.ndarray) -> jnp.ndarray:
    """Compute Anderson extrapolation from past observations."""
    # Remove -inf values to instantiate the quadratic problem. All other
    # values remain, since they might be caused by a valid issue.
    fxs_clean = jnp.nan_to_num(fxs, nan=jnp.nan, posinf=jnp.inf, neginf=0.0)
    xs_clean = jnp.nan_to_num(xs, nan=jnp.nan, posinf=jnp.inf, neginf=0.0)
    residuals = fxs_clean - xs_clean
    gram_matrix = jnp.matmul(residuals.T, residuals)
    gram_matrix /= jnp.linalg.norm(gram_matrix)

    # Solve the ridge-regularized linear system to obtain weights.
    weights = jax.scipy.sparse.linalg.cg(
        gram_matrix + self.ridge_identity * jnp.eye(xs.shape[1]),
        jnp.ones(xs.shape[1])
    )[0]
    weights /= jnp.sum(weights)

    # Recover the linear combination and return it with NaN entries (caused
    # by 0 weights leading to -jnp.inf potentials, mixed with weight
    # coefficients of different signs) disambiguated to -inf.
    combination = jnp.sum(fxs * weights[None, :], axis=1)
    return jnp.where(jnp.isfinite(combination), combination, -jnp.inf)

  def update(
      self,
      state: "sinkhorn.SinkhornState",
      iteration: int,
      prob: "linear_problem.LinearProblem",
      lse_mode: bool,
  ) -> "sinkhorn.SinkhornState":
    """Anderson acceleration update.

    When using Anderson acceleration, first update the dual variable f_u with
    an extrapolation of the previous iterates (if the iteration count is
    sufficiently large), then record the new iterate in the history arrays.

    Anderson acceleration always happens in potentials (not scalings) space,
    regardless of the lse_mode setting. If the extrapolation is triggered,
    its result is converted back to a scaling when lse_mode is False.

    Args:
      state: Sinkhorn state.
      iteration: the current iteration.
      prob: linear OT problem.
      lse_mode: whether to compute in log-sum-exp or in scalings.

    Returns:
      The updated Sinkhorn state.
    """
    geom = prob.geom
    trigger_update = jnp.logical_and(
        iteration > self.memory, iteration % self.refresh_every == 0
    )
    fu = jnp.where(
        trigger_update,
        self.extrapolation(state.old_fus, state.old_mapped_fus), state.fu
    )

    # If the interpolation was triggered, we store it in memory.
    # Otherwise we add the previous value (converting it to potential form if
    # it was initially stored in scaling form).
    old_fus = jnp.where(
        trigger_update,
        jnp.concatenate((state.old_fus[:, 1:], fu[:, None]), axis=1),
        jnp.concatenate((
            state.old_fus[:, 1:],
            (fu if lse_mode else geom.potential_from_scaling(fu))[:, None]
        ), axis=1)
    )

    # If the update was triggered, ensure a scaling is returned, since the
    # result of the extrapolation was output in potential form.
    fu = jnp.where(
        trigger_update, fu if lse_mode else geom.scaling_from_potential(fu), fu
    )
    return state.set(potentials=(fu, state.gv), old_fus=old_fus)

  def init_maps(
      self, pb, state: "sinkhorn.SinkhornState"
  ) -> "sinkhorn.SinkhornState":
    """Initialize Anderson acceleration history matrices with NaN values."""
    fus = jnp.ones((pb.geom.shape[0], self.memory)) * jnp.nan
    return state.set(old_fus=fus, old_mapped_fus=fus)

  def update_history(
      self, state: "sinkhorn.SinkhornState", pb, lse_mode: bool
  ) -> "sinkhorn.SinkhornState":
    """Update history of mapped dual variables."""
    f = state.fu if lse_mode else pb.geom.potential_from_scaling(state.fu)
    mapped = jnp.concatenate((state.old_mapped_fus[:, 1:], f[:, None]), axis=1)
    return state.set(old_mapped_fus=mapped)
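

# Illustrative usage sketch, not part of the original module: how an
# ``AndersonAcceleration`` instance is typically handed to the Sinkhorn
# solver. The helper name ``_example_anderson_usage`` and the parameter
# values are hypothetical; the sketch assumes the ``Sinkhorn`` constructor
# accepts an ``anderson=`` keyword and that ``PointCloud`` / ``LinearProblem``
# behave as in recent OTT-JAX releases.
def _example_anderson_usage():
  import jax

  # Imported inside the function to avoid a circular import, since the
  # ``sinkhorn`` module itself imports this acceleration module.
  from ott.geometry import pointcloud
  from ott.problems.linear import linear_problem
  from ott.solvers.linear import sinkhorn

  x = jax.random.normal(jax.random.PRNGKey(0), (64, 3))
  y = jax.random.normal(jax.random.PRNGKey(1), (57, 3))
  prob = linear_problem.LinearProblem(pointcloud.PointCloud(x, y))
  # Keep the last 5 iterates in the history and refresh the interpolation
  # every other Sinkhorn update (illustrative values).
  solver = sinkhorn.Sinkhorn(
      anderson=AndersonAcceleration(memory=5, refresh_every=2)
  )
  return solver(prob)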


@utils.register_pytree_node
class Momentum:
  """Momentum for Sinkhorn updates.

  Can be either constant :cite:`thibault:21` or adaptive :cite:`lehmann:21`.
  """

  start: int = 0
  error_threshold: float = jnp.inf
  value: float = 1.0
  inner_iterations: int = 1

  def weight(self, state: "sinkhorn.SinkhornState", iteration: int) -> float:
    """Compute momentum term if needed, using previously seen errors."""
    if self.start == 0:
      return self.value
    idx = self.start // self.inner_iterations

    return jax.lax.cond(
        jnp.logical_and(
            iteration >= self.start,
            state.errors[idx - 1, -1] < self.error_threshold
        ),
        lambda state: self.lehmann(state),
        lambda state: self.value,
        state,
    )

  def lehmann(self, state: "sinkhorn.SinkhornState") -> float:
    """Momentum formula :cite:`lehmann:21`, eq. 5."""
    idx = self.start // self.inner_iterations
    error_ratio = jnp.minimum(
        state.errors[idx - 1, -1] / state.errors[idx - 2, -1], 0.99
    )
    power = 1.0 / self.inner_iterations
    return 2.0 / (1.0 + jnp.sqrt(1.0 - error_ratio ** power))

  def __call__(  # noqa: D102
      self,
      weight: float,
      value: jnp.ndarray,
      new_value: jnp.ndarray,
      lse_mode: bool = True
  ) -> jnp.ndarray:
    if lse_mode:
      value = jnp.where(jnp.isfinite(value), value, 0.0)
      return (1.0 - weight) * value + weight * new_value
    value = jnp.where(value > 0.0, value, 1.0)
    return value ** (1.0 - weight) * new_value ** weight
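

# Illustrative usage sketch, not part of the original module: constant vs.
# adaptive momentum. With ``start=0`` the fixed ``value`` is applied at every
# update; with ``start > 0`` the adaptive rule of :cite:`lehmann:21` kicks in
# once enough errors have been recorded. The helper name and parameter values
# are hypothetical; the sketch assumes the ``Sinkhorn`` constructor accepts a
# ``momentum=`` keyword, as in recent OTT-JAX releases.
def _example_momentum_usage(prob: "linear_problem.LinearProblem"):
  from ott.solvers.linear import sinkhorn

  # Constant over-relaxation with weight 1.5, in the spirit of
  # :cite:`thibault:21` (values in (1, 2) act as over-relaxation).
  constant = sinkhorn.Sinkhorn(momentum=Momentum(value=1.5))
  # Adaptive momentum, estimated from the ratio of errors recorded around
  # iteration 100 (see ``Momentum.lehmann`` above).
  adaptive = sinkhorn.Sinkhorn(momentum=Momentum(start=100))
  return constant(prob), adaptive(prob)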