# Source code for ott.solvers.was_solver

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, Sequence, Tuple, Union

import jax
import jax.numpy as jnp

from ott.solvers.linear import sinkhorn, sinkhorn_lr

if TYPE_CHECKING:
  from ott.solvers.linear import continuous_barycenter

__all__ = ["WassersteinSolver"]

# Union of the solver-state types that ``WassersteinSolver`` methods accept as
# ``state``. The barycenter state is written as a string (forward reference)
# because ``continuous_barycenter`` is only imported under ``TYPE_CHECKING``.
State = Union[
    sinkhorn.SinkhornState,
    sinkhorn_lr.LRSinkhornState,
    "continuous_barycenter.FreeBarycenterState",
]


@jax.tree_util.register_pytree_node_class
class WassersteinSolver:
  """A generic solver for problems that use a linear problem in inner loop.

  The instance bundles an inner linear OT solver with the parameters of the
  outer loop. It also provides helpers that inspect the outer-loop cost
  history in ``state.costs`` to decide whether iterations should stop.

  Args:
    linear_solver: Solver for the inner linear problems. Passing an
      ``LRSinkhorn`` instance makes this solver low-rank (see
      :attr:`is_low_rank`).
    threshold: Relative tolerance used by the convergence check when
      comparing the two most recent outer-loop costs.
    min_iterations: Minimum number of outer iterations. Stored on the
      instance; not read by the methods of this class (presumably consumed
      by the outer loop or subclasses -- confirm).
    max_iterations: Maximum number of outer iterations. Stored on the
      instance; not read by the methods of this class.
    store_inner_errors: Flag stored on the instance; not read by the methods
      of this class.
  """

  def __init__(
      self,
      linear_solver: Union["sinkhorn.Sinkhorn", "sinkhorn_lr.LRSinkhorn"],
      threshold: float = 1e-3,
      min_iterations: int = 5,
      max_iterations: int = 50,
      store_inner_errors: bool = False,
  ):
    self.linear_solver = linear_solver
    self.threshold = threshold
    self.min_iterations = min_iterations
    self.max_iterations = max_iterations
    self.store_inner_errors = store_inner_errors

  @property
  def rank(self) -> int:
    """Rank of the inner linear solver, or ``-1`` when it is not low-rank."""
    if not self.is_low_rank:
      return -1
    return self.linear_solver.rank

  @property
  def is_low_rank(self) -> bool:
    """Whether the inner linear solver is an ``LRSinkhorn`` instance."""
    solver = self.linear_solver
    return isinstance(solver, sinkhorn_lr.LRSinkhorn)

  def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:
    """Split the solver into pytree children and static auxiliary data.

    ``linear_solver`` and ``threshold`` become the children, in the same
    order as the first two positional parameters of ``__init__``. The
    remaining constructor arguments are returned as keyword auxiliary data.
    """
    children = [self.linear_solver, self.threshold]
    aux_data = {
        "min_iterations": self.min_iterations,
        "max_iterations": self.max_iterations,
        "store_inner_errors": self.store_inner_errors,
    }
    return children, aux_data

  @classmethod
  def tree_unflatten(
      cls, aux_data: Dict[str, Any], children: Sequence[Any]
  ) -> "WassersteinSolver":
    """Rebuild a solver from the output of :meth:`tree_flatten`."""
    # Children are forwarded positionally, auxiliary data as keywords.
    return cls(*children, **aux_data)

  def _converged(self, state: State, iteration: int) -> bool:
    """Check whether the outer-loop cost has stopped changing.

    Compares ``state.costs[iteration - 2]`` and ``state.costs[iteration - 1]``
    with ``jnp.isclose`` using ``rtol=self.threshold``. When
    ``iteration < 2`` the result is always ``False``.
    """
    history = state.costs
    # For iteration < 2 these indices are negative; whatever they read is
    # discarded by the ``enough_history`` mask below.
    before_last = history[iteration - 2]
    last = history[iteration - 1]
    enough_history = iteration >= 2
    stationary = jnp.isclose(before_last, last, rtol=self.threshold)
    return jnp.logical_and(enough_history, stationary)

  def _diverged(self, state: State, iteration: int) -> bool:
    """Whether the latest cost ``state.costs[iteration - 1]`` is not finite.

    "Not finite" here means NaN or infinite.
    """
    latest_cost = state.costs[iteration - 1]
    return jnp.logical_not(jnp.isfinite(latest_cost))

  def _continue(self, state: State, iteration: int) -> bool:
    """Decide whether the outer loop should run another iteration.

    While ``iteration <= 2`` the result is always ``True``. After that, the
    result is ``True`` only if the latest cost is finite and the last two
    costs are not yet close (see :meth:`_converged`).
    """
    diverged = self._diverged(state, iteration)
    converged = self._converged(state, iteration)
    # De Morgan: not(diverged) and not(converged) == not(diverged or converged)
    should_stop = jnp.logical_or(diverged, converged)
    return jnp.logical_or(iteration <= 2, jnp.logical_not(should_stop))