# Source code for ott.math._lbfgs
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Tuple
import jax
import jax.numpy as jnp
import optax
__all__ = ["lbfgs"]
# see https://optax.readthedocs.io/en/stable/_collections/examples/lbfgs.html
def run_opt(
    opt: optax.GradientTransformationExtraArgs,
    x_init: jnp.ndarray,
    fun: Callable[[jnp.ndarray], jnp.ndarray],
    max_iter: int,
    tol: float,
) -> Tuple[jnp.ndarray, optax.OptState]:
    """Minimize a function by iterating an optax optimizer.

    The optimizer is stepped inside a :func:`jax.lax.while_loop` until the
    iteration counter stored in its state reaches ``max_iter`` or the norm of
    the gradient stored in its state falls below ``tol``.

    Args:
      opt: An instance of an optax optimizer. Its state must expose
        ``"count"`` and ``"grad"`` entries, which are read to decide when
        to stop.
      x_init: Initial point to start optimization.
      fun: The function to minimize.
      max_iter: Maximum number of iterations.
      tol: Tolerance for convergence, measured as the norm of the gradient.

    Returns:
      The final optimization variable and the final optimizer state.
    """
    # Wrapper that receives the optimizer state alongside the parameters when
    # evaluating the value and gradient of ``fun``.
    value_and_grad = optax.value_and_grad_from_state(fun)

    def keep_going(loop_vars):
        opt_state = loop_vars[1]
        n_steps = optax.tree.get(opt_state, "count")
        grad_norm = optax.tree.norm(optax.tree.get(opt_state, "grad"))
        not_converged = (n_steps < max_iter) & (grad_norm >= tol)
        # While the counter is 0 the loop always continues, irrespective of
        # ``max_iter``, ``tol`` and the gradient held in the initial state.
        return (n_steps == 0) | not_converged

    def body(loop_vars):
        x, opt_state = loop_vars
        f_val, f_grad = value_and_grad(x, state=opt_state)
        # The value, gradient and objective itself are forwarded as extra
        # arguments to the optimizer update.
        delta, next_state = opt.update(
            f_grad, opt_state, x, value=f_val, grad=f_grad, value_fn=fun
        )
        return optax.apply_updates(x, delta), next_state

    x_final, state_final = jax.lax.while_loop(
        keep_going, body, (x_init, opt.init(x_init))
    )
    return x_final, state_final
# [docs]
def lbfgs(
    fun: Callable[[jnp.ndarray], jnp.ndarray],
    x_init: jnp.ndarray,
    max_iter: int = 100,
    tol: float = 1e-4,
    **kwargs: Any,
) -> Tuple[jnp.ndarray, optax.OptState]:
    """Minimize a function with optax's L-BFGS solver.

    Args:
      fun: The function to minimize.
      x_init: Initial point to start optimization.
      max_iter: Maximum number of iterations.
      tol: Tolerance for convergence, forwarded to :func:`run_opt`.
      kwargs: Keyword arguments for :func:`optax.lbfgs`.

    Returns:
      Final optimization variable obtained after running L-BFGS and state.
    """
    # Build the solver from the user-supplied options and hand it straight to
    # the generic optimization loop.
    return run_opt(
        optax.lbfgs(**kwargs),
        x_init,
        fun,
        max_iter=max_iter,
        tol=tol,
    )