# Source code for ott.problems.linear.semidiscrete_linear_problem

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional

import jax
import jax.numpy as jnp
import jax.tree_util as jtu

from ott.geometry import semidiscrete_pointcloud
from ott.problems.linear import linear_problem

__all__ = ["SemidiscreteLinearProblem"]


@jtu.register_pytree_node_class
class SemidiscreteLinearProblem:
    """Semidiscrete linear OT problem.

    Instances of this problem can be sampled using the :meth:`sample` method.

    Args:
      geom: Semidiscrete point cloud geometry.
      b: The second marginal. If :obj:`None`, it will be uniform.
      tau_b: If :math:`< 1`, defines how much unbalanced the problem is
        on the second marginal. Currently not implemented.

    Raises:
      AssertionError: If ``tau_b`` is not ``1.0``.
    """

    def __init__(
        self,
        geom: semidiscrete_pointcloud.SemidiscretePointCloud,
        b: Optional[jax.Array] = None,
        tau_b: float = 1.0,
    ):
        # Raise explicitly rather than using an `assert` statement so that the
        # check is not stripped when Python runs with `-O`. The exception type
        # and message are the same as before.
        if tau_b != 1.0:
            raise AssertionError(
                "Unbalanced semidiscrete problem is not supported."
            )
        self.geom = geom
        # Keep the user-supplied value (possibly `None`); the uniform default
        # is materialized lazily by the `b` property.
        self._b = b
        self.tau_b = tau_b

    def sample(
        self,
        rng: jax.Array,
        num_samples: int,
        *,
        epsilon: Optional[float] = None,
    ) -> linear_problem.LinearProblem:
        """Sample a linear OT problem.

        Args:
          rng: Random key used for seeding.
          num_samples: Number of samples.
          epsilon: Epsilon regularization. If :obj:`None`, use :attr:`epsilon`.

        Returns:
          The sampled linear problem.
        """
        if epsilon is None:
            epsilon = self.epsilon
        geom = self.geom.sample(rng, num_samples, epsilon=epsilon)
        # The raw `_b` (which may be `None`) is forwarded unchanged, together
        # with `a=None`, to the sampled problem.
        return linear_problem.LinearProblem(
            geom, a=None, b=self._b, tau_a=1.0, tau_b=self.tau_b
        )

    def potential_fn_from_dual_vec(
        self, g: jax.Array, *, epsilon: Optional[float] = None
    ) -> Callable[[jax.Array], jax.Array]:
        r"""Get potential function from a dual vector using the :term:`c-transform`.

        Args:
          g: Potential vector :math:`\mathbf{g}` of shape ``[m,]``.
          epsilon: Epsilon regularization. If :obj:`None`, use the one in
            :attr:`geom`.

        Returns:
          The dual potential function :math:`f`.
        """
        # `potential_fn_from_dual_vec` accesses only necessary properties of
        # the problem/geometry, so we can pass the semidiscrete point cloud
        prob = linear_problem.LinearProblem(self.geom, b=self.b)
        return prob.potential_fn_from_dual_vec(g, epsilon=epsilon, axis=1)

    @property
    def b(self) -> jax.Array:
        """Second marginal.

        Returns the array passed at construction, or uniform weights
        ``1 / m`` when none was given, where ``m`` is the second entry of
        ``geom.shape``.
        """
        if self._b is not None:
            return self._b
        _, m = self.geom.shape
        return jnp.full((m,), fill_value=1.0 / m, dtype=self.geom.y.dtype)

    @property
    def epsilon(self) -> jax.Array:
        """Entropic regularization."""
        return self.geom.epsilon

    def tree_flatten(self):  # noqa: D102
        # Children: the geometry and the raw second marginal.
        # Auxiliary data: the static `tau_b`.
        return (self.geom, self._b), {"tau_b": self.tau_b}

    @classmethod
    def tree_unflatten(  # noqa: D102
        cls, aux_data, children
    ) -> "SemidiscreteLinearProblem":
        # Inverse of `tree_flatten`: children map positionally to
        # `(geom, b)` and `aux_data` supplies `tau_b`.
        return cls(*children, **aux_data)