Source code for ott.solvers.linear.lineax_implicit

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, TypeVar

import equinox as eqx
import lineax as lx
from jaxtyping import Array, Float, PyTree

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

_T = TypeVar("_T")
_FlatPyTree = tuple[list[_T], jtu.PyTreeDef]

__all__ = ["CustomTransposeLinearOperator", "solve_lineax"]


class CustomTransposeLinearOperator(lx.FunctionLinearOperator):
  """Implement a linear operator that can specify its transpose directly."""
  fn: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
  fn_t: Callable[[PyTree[Float[Array, "..."]]], PyTree[Float[Array, "..."]]]
  input_structure: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
  input_structure_t: _FlatPyTree[jax.ShapeDtypeStruct] = eqx.static_field()
  tags: frozenset[object]

  def __init__(self, fn, fn_t, input_structure, input_structure_t, tags=()):
    super().__init__(fn, input_structure, tags)
    self.fn_t = eqx.filter_closure_convert(fn_t, input_structure_t)
    self.input_structure_t = input_structure_t

  def transpose(self):
    """Provide custom transposition operator from function."""
    return lx.FunctionLinearOperator(self.fn_t, self.input_structure_t)


[docs] def solve_lineax( lin: Callable, b: jnp.ndarray, lin_t: Optional[Callable] = None, symmetric: bool = False, nonsym_solver: Optional[lx.AbstractLinearSolver] = None, ridge_identity: float = 0.0, ridge_kernel: float = 0.0, **kwargs: Any ) -> jnp.ndarray: """Wrapper around lineax solvers. Args: lin: Linear operator b: vector. Returned `x` is such that `lin(x)=b` lin_t: Linear operator, corresponding to transpose of `lin`. symmetric: whether `lin` is symmetric. nonsym_solver: solver used when handling non-symmetric cases. Note that :class:`~lineax.CG` is used by default in the symmetric case. ridge_kernel: promotes zero-sum solutions. Only use if `tau_a = tau_b = 1.0` ridge_identity: handles rank deficient transport matrices (this happens typically when rows/cols in cost/kernel matrices are collinear, or, equivalently when two points from either measure are close). kwargs: arguments passed to :class:`~lineax.AbstractLinearSolver` linear solver. """ input_structure = jax.eval_shape(lambda: b) kwargs.setdefault("rtol", 1e-6) kwargs.setdefault("atol", 1e-6) if ridge_kernel > 0.0 or ridge_identity > 0.0: lin_reg = lambda x: lin(x) + ridge_kernel * jnp.sum(x) + ridge_identity * x lin_t_reg = lambda x: lin_t(x) + ridge_kernel * jnp.sum( x ) + ridge_identity * x else: lin_reg, lin_t_reg = lin, lin_t if symmetric: solver = lx.CG(**kwargs) fn_operator = lx.FunctionLinearOperator( lin_reg, input_structure, tags=lx.positive_semidefinite_tag ) return lx.linear_solve(fn_operator, b, solver).value # In the non-symmetric case, use NormalCG by default, but consider # user defined choice of alternative lx solver. solver_type = lx.NormalCG if nonsym_solver is None else nonsym_solver solver = solver_type(**kwargs) fn_operator = CustomTransposeLinearOperator( lin_reg, lin_t_reg, input_structure, input_structure ) return lx.linear_solve(fn_operator, b, solver).value