# Source code for ott.geometry.epsilon_scheduler
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import jax.numpy as jnp
import jax.tree_util as jtu
# Names exported by ``from <this module> import *``.
__all__ = ["Epsilon", "DEFAULT_EPSILON_SCALE"]
#: Scaling applied to a statistic (mean/std) of the cost to compute the
#: default epsilon.
DEFAULT_EPSILON_SCALE = 0.05
# [docs]
@jtu.register_pytree_node_class
class Epsilon:
  r"""Scheduler class for the regularization parameter epsilon.

  An epsilon scheduler outputs a regularization strength, to be used by the
  :term:`Sinkhorn algorithm` or variant, at any iteration count. That value is
  either the final, targeted regularization, or one that is larger, obtained by
  geometric decay of an initial multiplier.

  Concretely, at iteration ``it`` the returned value is
  :math:`\max(\text{init} \cdot \text{decay}^{\text{it}}, 1) \cdot
  \text{target}`.

  Args:
    target: The epsilon regularizer that is targeted.
    init: Initial value when using epsilon scheduling, understood as a multiple
      of the ``target``, following :math:`\text{init} \text{decay}^{\text{it}}`.
    decay: Geometric decay factor, :math:`\leq 1`.

  Raises:
    AssertionError: If ``decay`` is not :math:`\leq 1`.
  """

  def __init__(
      self, target: jnp.ndarray, init: float = 1.0, decay: float = 1.0
  ):
    # Explicit raise instead of a bare ``assert`` so that the check is not
    # stripped under ``python -O``. The exception type and message are kept
    # identical to the previous ``assert`` for backward compatibility.
    if not decay <= 1.0:
      raise AssertionError(f"Decay must be <= 1, found {decay}.")
    self.target = target
    self.init = init
    self.decay = decay

  def __call__(self, it: Optional[int]) -> jnp.ndarray:
    """Intermediate regularizer value at a given iteration number.

    Args:
      it: Current iteration. If :obj:`None`, return :attr:`target`.

    Returns:
      The epsilon value at the iteration.
    """
    if it is None:
      return self.target
    # the multiple is either 1.0 or a larger init value that is decayed.
    multiple = jnp.maximum(self.init * (self.decay ** it), 1.0)
    return multiple * self.target

  @staticmethod
  def _format_value(value) -> str:
    """Format ``value`` with 4 decimals, falling back to :func:`repr`.

    The fallback keeps :meth:`__repr__` usable when an attribute does not
    support the ``.4f`` format specification (e.g. a non-scalar array or a
    placeholder such as :obj:`None` substituted for a pytree leaf).
    """
    try:
      return f"{value:.4f}"
    except (TypeError, ValueError):
      return repr(value)

  def __repr__(self) -> str:
    """Return a readable description; never raises on unusual attributes."""
    fmt = self._format_value
    return (
        f"{self.__class__.__name__}(target={fmt(self.target)}, "
        f"init={fmt(self.init)}, decay={fmt(self.decay)})"
    )

  def tree_flatten(self):  # noqa: D102
    # ``target`` is the only pytree child; ``init`` and ``decay`` are carried
    # as auxiliary data and passed back as keyword arguments on unflatten.
    return (self.target,), {"init": self.init, "decay": self.decay}

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    return cls(*children, **aux_data)