# Source code for ott.geometry.epsilon_scheduler

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional

import jax
import jax.numpy as jnp

__all__ = ["Epsilon"]


@jax.tree_util.register_pytree_node_class
class Epsilon:
  """Scheduler class for the regularization parameter epsilon.

  An epsilon scheduler outputs a regularization strength, to be used in a
  Sinkhorn-type algorithm, at any iteration count. That value is either the
  final, targeted regularization, or one that is larger, obtained by geometric
  decay of an initial value that is larger than the intended target.

  Concretely, the value returned at a given iteration is
  ``max(init * decay ** iteration, 1) * scale_epsilon * target``, i.e. the
  (rescaled) target multiplied by a factor that starts at ``init`` and decays
  geometrically until it reaches 1.

  Args:
    target: the epsilon regularizer that is targeted. If ``None``, use
      :math:`0.05`.
    scale_epsilon: if passed, used to multiply the regularizer, to rescale it.
      If ``None``, use :math:`1`.
    init: initial value when using epsilon scheduling, understood as multiple
      of target value. If passed, ``init * decay ** iteration`` will be used
      to rescale target.
    decay: geometric decay factor, :math:`<1`. Values above 1 are clamped
      to 1 when the schedule is evaluated.
  """

  def __init__(
      self,
      target: Optional[float] = None,
      scale_epsilon: Optional[float] = None,
      init: float = 1.0,
      decay: float = 1.0
  ):
    # The raw (possibly ``None``) values are stored; defaults are resolved
    # lazily in :attr:`target` so that ``set`` and pytree flattening can
    # round-trip exactly what the user passed.
    self._target_init = target
    self._scale_epsilon = scale_epsilon
    self._init = init
    self._decay = decay

  @property
  def target(self) -> float:
    """Return the final regularizer value of scheduler."""
    target = 5e-2 if self._target_init is None else self._target_init
    scale = 1.0 if self._scale_epsilon is None else self._scale_epsilon
    return scale * target

  def at(self, iteration: Optional[int] = 1) -> float:
    """Return (intermediate) regularizer value at a given iteration.

    Args:
      iteration: iteration count at which the schedule is evaluated. If
        ``None``, the final value :attr:`target` is returned directly.

    Returns:
      The regularizer value: :attr:`target` multiplied by
      ``max(init * min(decay, 1) ** iteration, 1)``.
    """
    if iteration is None:
      return self.target
    # check the decay is smaller than 1.0.
    decay = jnp.minimum(self._decay, 1.0)
    # the multiple is either 1.0 or a larger init value that is decayed.
    multiple = jnp.maximum(self._init * (decay ** iteration), 1.0)
    return multiple * self.target

  def done(self, eps: float) -> bool:
    """Return whether the scheduler is done at a given value.

    Args:
      eps: regularizer value to compare against :attr:`target`.

    Returns:
      Result of the exact equality test ``eps == self.target``.
    """
    return eps == self.target

  def done_at(self, iteration: Optional[int]) -> bool:
    """Return whether the scheduler is done at a given iteration.

    Args:
      iteration: iteration count, forwarded to :meth:`at`.

    Returns:
      Whether the value produced by :meth:`at` equals :attr:`target`.
    """
    return self.done(self.at(iteration))

  def set(self, **kwargs: Any) -> "Epsilon":
    """Return a copy of self, with potential overwrites.

    Args:
      kwargs: any of ``target``, ``scale_epsilon``, ``init`` or ``decay``;
        the passed values replace the ones stored in this instance.

    Returns:
      A new :class:`Epsilon`; ``self`` is left unchanged.
    """
    # Current settings first, so that user-provided entries take precedence.
    kwargs = {
        "target": self._target_init,
        "scale_epsilon": self._scale_epsilon,
        "init": self._init,
        "decay": self._decay,
        **kwargs
    }
    return Epsilon(**kwargs)

  def tree_flatten(self):  # noqa: D102
    # All four constructor arguments are children; there is no aux data.
    return (
        self._target_init, self._scale_epsilon, self._init, self._decay
    ), None

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    del aux_data
    # Children come back in the same order as the ``__init__`` parameters.
    return cls(*children)