# Source code for ott.math._legendre
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional
import jax
import jax.numpy as jnp
from ott.math import _lbfgs as lbfgs
__all__ = ["legendre"]
# [docs]  (documentation-link marker left over from the rendered source page)
def legendre(
    fun: Callable[[jnp.ndarray], jnp.ndarray],
    **kwargs: Any,
) -> Callable[[jnp.ndarray, Optional[jnp.ndarray]], jnp.ndarray]:
    """Legendre (Fenchel) transform of a function.

    The transform ``fun*(x) = max_z <x, z> - fun(z)`` is computed
    numerically using L-BFGS.

    Args:
      fun: A function to be transformed, must be convex for the transform
        to be properly defined.
      kwargs: Keyword arguments for :func:`~ott.math.lbfgs`, e.g. maximal
        iterations ``max_iters``, convergence tolerance ``tol`` or
        :func:`optax.lbfgs` arguments. They are forwarded unchanged on every
        evaluation of the returned function.

    Returns:
      A function ``fun_star(x, x_init=None)`` that computes numerically the
      Legendre transform of the ``fun`` at a given point.
    """

    def fun_star(
        x: jnp.ndarray,
        x_init: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Runs optimization to compute the Legendre transform of ``fun``.

        Args:
          x: Array of shape ``[d,]`` where to evaluate the function.
          x_init: Initialization for optimization, of the same shape as
            ``x``. If :obj:`None`, use ``x``.

        Returns:
          The Legendre transform of the ``fun`` evaluated at ``x``.
        """
        x_init = x if x_init is None else x_init

        def mod_fun(z: jnp.ndarray) -> jnp.ndarray:
            """Conjugate maximizes <x,z> - fun(z); minimize fun(z) - <x,z>."""
            return fun(z) - jnp.dot(x, z)

        # Only the minimizer is needed; the second output of the solver is
        # discarded.
        z, _ = lbfgs.lbfgs(fun=mod_fun, x_init=x_init, **kwargs)
        # Flip sign again to revert to maximization convention. The optimizer
        # output is wrapped in ``stop_gradient`` so that differentiating the
        # result does not propagate through ``z``; ``x`` still enters through
        # the explicit ``<x, z>`` term of ``mod_fun``.
        return -mod_fun(jax.lax.stop_gradient(z))

    return fun_star