# Source code for ott.geometry.distrib_costs
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, Optional
import jax.numpy as jnp
import jax.tree_util as jtu
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import univariate
__all__ = ["UnivariateWasserstein"]
# [docs]
@jtu.register_pytree_node_class
class UnivariateWasserstein(costs.CostFn):
  """1D Wasserstein cost for two 1D distributions.

  This cost treats each input vector as a family of scalar values, i.e., as
  the support of a 1D distribution. The cost between two such vectors is the
  1D optimal transport cost between those distributions, computed with a
  user-defined ground cost.

  Args:
    solve_fn: 1D optimal transport solver, e.g.,
      :func:`~ott.solvers.linear.univariate.uniform_distance`. It is called
      with a :class:`~ott.problems.linear.linear_problem.LinearProblem` and
      its output must expose an ``ot_costs`` attribute.
    ground_cost: Cost used to compute the 1D optimal transport between the
      values of the two vectors. Should be a translation-invariant (TI) cost
      for correctness. If :obj:`None`, defaults to
      :class:`~ott.geometry.costs.SqEuclidean`.
  """

  def __init__(
      self,
      solve_fn: Callable[[linear_problem.LinearProblem],
                         univariate.UnivariateOutput],
      ground_cost: Optional[costs.TICost] = None,
  ):
    super().__init__()
    self.ground_cost = (
        costs.SqEuclidean() if ground_cost is None else ground_cost
    )
    self._solve_fn = solve_fn

  def __call__(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Wasserstein distance between :math:`x` and :math:`y` as 1D distributions.

    Args:
      x: Array of shape ``[n,]``.
      y: Array of shape ``[m,]``.

    Returns:
      The transport cost, i.e., the solver's ``ot_costs`` with singleton
      dimensions squeezed out.
    """
    # Lift each scalar value to a point in R^1 so that ``ground_cost`` is
    # evaluated between pairs of values: ``[n,] -> [n, 1]``, ``[m,] -> [m, 1]``.
    geom = pointcloud.PointCloud(
        x[:, None], y[:, None], cost_fn=self.ground_cost
    )
    # No marginal weights are passed, so the problem's default weights apply.
    # NOTE(review): presumably uniform -- confirm against LinearProblem.
    prob = linear_problem.LinearProblem(geom)
    out = self._solve_fn(prob)
    return jnp.squeeze(out.ot_costs)

  def tree_flatten(self):  # noqa: D102
    # ``ground_cost`` is a pytree child (flattened further by JAX); the solver
    # callable is stored as static auxiliary data rather than traced.
    return (self.ground_cost,), (self._solve_fn,)

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    # Inverse of ``tree_flatten``: aux_data = (solve_fn,), children = (cost,).
    return cls(solve_fn=aux_data[0], ground_cost=children[0])