Source code for ott.neural.networks.layers.conjugate
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import Callable, Literal, NamedTuple, Optional
import jax.numpy as jnp
from jaxopt import LBFGS
from ott import utils
# Names exported by ``from <this module> import *``; all four are defined below.
__all__ = [
"ConjugateResults",
"FenchelConjugateSolver",
"FenchelConjugateLBFGS",
"DEFAULT_CONJUGATE_SOLVER",
]
[docs]
class ConjugateResults(NamedTuple):
    r"""Output of a numerical Fenchel conjugation at a point :math:`y`.

    Args:
      val: value of the conjugate at :math:`y`, i.e., :math:`f^\star(y)`
      grad: gradient of the conjugate at :math:`y`, i.e.,
        :math:`\nabla f^\star(y)`
      num_iter: how many iterations the solver performed
    """

    val: float
    grad: jnp.ndarray
    num_iter: int
[docs]
class FenchelConjugateSolver(abc.ABC):
r"""Abstract conjugate solver class.
Given a function :math:`f`, numerically estimate the Fenchel conjugate
:math:`f^\star(y) := -\inf_{x\in\mathbb{R}^n} f(x)-\langle x, y\rangle`.
"""
[docs]
@abc.abstractmethod
def solve(
self,
f: Callable[[jnp.ndarray], jnp.ndarray],
y: jnp.ndarray,
x_init: Optional[jnp.ndarray] = None
) -> ConjugateResults:
"""Solve for the conjugate.
Args:
f: function to conjugate
y: point to conjugate
x_init: initial point to search over
Returns:
The solution to the conjugation.
"""
[docs]
# NOTE(review): `utils.register_pytree_node` is defined outside this file; it
# presumably turns the class into a dataclass (keyword construction is used by
# DEFAULT_CONJUGATE_SOLVER) and registers it as a JAX pytree -- confirm.
@utils.register_pytree_node
class FenchelConjugateLBFGS(FenchelConjugateSolver):
    """Solve for the conjugate using :class:`~jaxopt.LBFGS`.

    Args:
      gtol: gradient tolerance
      max_iter: maximum number of iterations
      max_linesearch_iter: maximum number of line search iterations
      linesearch_type: type of line search
      linesearch_init: strategy for line search initialization
      increase_factor: factor by which to increase the step size during
        the line search
    """

    gtol: float = 1e-3
    max_iter: int = 10
    max_linesearch_iter: int = 10
    linesearch_type: Literal["zoom", "backtracking",
                             "hager-zhang"] = "backtracking"
    linesearch_init: Literal["increase", "max", "current"] = "increase"
    increase_factor: float = 1.5

    def solve(
        self,
        f: Callable[[jnp.ndarray], jnp.ndarray],
        y: jnp.ndarray,
        x_init: Optional[jnp.ndarray] = None
    ) -> ConjugateResults:
        """Conjugate ``f`` at ``y`` by minimizing ``f(x) - <x, y>`` with L-BFGS.

        Args:
          f: function to conjugate
          y: point to conjugate, must be one-dimensional
          x_init: initial point of the search; when :obj:`None`, ``y`` itself
            is used as the starting point

        Returns:
          The conjugation results: ``val`` is the negated final objective
          value, ``grad`` is the point returned by the solver and ``num_iter``
          is the solver's iteration counter.

        Raises:
          AssertionError: if ``y`` is not one-dimensional.
        """
        assert y.ndim == 1, y.ndim

        solver = LBFGS(
            # The conjugate is -inf_x f(x) - <x, y>, so minimize this
            # objective and negate its value afterwards.
            fun=lambda x: f(x) - x.dot(y),
            tol=self.gtol,
            maxiter=self.max_iter,
            # Forward the line search budget; without this the configured
            # `max_linesearch_iter` would be silently ignored in favour of
            # the jaxopt default.
            maxls=self.max_linesearch_iter,
            linesearch=self.linesearch_type,
            linesearch_init=self.linesearch_init,
            increase_factor=self.increase_factor,
            implicit_diff=False,
            unroll=False
        )

        start = y if x_init is None else x_init
        out = solver.run(start)
        return ConjugateResults(
            val=-out.state.value,
            grad=out.params,
            num_iter=out.state.iter_num,
        )
# Module-level default solver. Compared with the class defaults of
# FenchelConjugateLBFGS it uses a tighter gradient tolerance (1e-5 vs 1e-3)
# and doubles both iteration budgets (20 vs 10).
DEFAULT_CONJUGATE_SOLVER = FenchelConjugateLBFGS(
    linesearch_type="backtracking",
    max_linesearch_iter=20,
    max_iter=20,
    gtol=1e-5,
)