# Source code for ott.geometry.distrib_costs
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
import jax.numpy as jnp
import jax.tree_util as jtu
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import univariate
__all__ = [
"UnivariateWasserstein",
]
# [docs]
@jtu.register_pytree_node_class
class UnivariateWasserstein(costs.CostFn):
  """1D Wasserstein cost between two vectors seen as 1D distributions.

  This cost treats each input vector as a family of scalar values. The cost
  between two such vectors is the 1D optimal transport cost between them,
  computed with a user-defined ground cost.

  Args:
    ground_cost: Cost used to compute the 1D optimal transport between the
      vectors, should be a translation-invariant (TI) cost for correctness.
      If :obj:`None`, defaults to :class:`~ott.geometry.costs.SqEuclidean`.
    solver: 1D optimal transport solver. If :obj:`None`, a default-constructed
      :class:`~ott.solvers.linear.univariate.UnivariateSolver` is used.
    kwargs: Arguments passed on when calling the
      :class:`~ott.solvers.linear.univariate.UnivariateSolver`. May include
      random key, or specific instructions to subsample or compute using
      quantiles. ``return_transport`` is always overridden to :obj:`False`.
  """

  def __init__(
      self,
      ground_cost: Optional[costs.TICost] = None,
      solver: Optional[univariate.UnivariateSolver] = None,
      **kwargs: Any
  ):
    super().__init__()

    if ground_cost is None:
      ground_cost = costs.SqEuclidean()
    self.ground_cost = ground_cost

    if solver is None:
      solver = univariate.UnivariateSolver()
    self._solver = solver

    # Only the scalar cost is needed here, so the solver is always told not
    # to compute or store transport solutions, whatever the caller passed.
    self._kwargs_solve = {**kwargs, "return_transport": False}

  def pairwise(self, x: jnp.ndarray, y: jnp.ndarray) -> float:
    """Wasserstein distance between :math:`x` and :math:`y` seen as a 1D dist.

    Args:
      x: Array of shape ``[n,]``.
      y: Array of shape ``[m,]``.

    Returns:
      The transport cost.
    """
    # Add a trailing axis so that every entry becomes one point of a
    # point cloud, then let the univariate solver handle the OT problem.
    geom = pointcloud.PointCloud(
        x[:, jnp.newaxis], y[:, jnp.newaxis], cost_fn=self.ground_cost
    )
    prob = linear_problem.LinearProblem(geom)
    solution = self._solver(prob, **self._kwargs_solve)
    return jnp.squeeze(solution.ot_costs)

  def tree_flatten(self):  # noqa: D102
    # The ground cost is the only pytree child; the solver and the solve
    # keyword arguments travel as auxiliary data.
    children = (self.ground_cost,)
    aux_data = (self._solver, self._kwargs_solve)
    return children, aux_data

  @classmethod
  def tree_unflatten(cls, aux_data, children):  # noqa: D102
    # Mirror of ``tree_flatten``: exactly one child, two aux entries.
    (ground_cost,) = children
    solver, solve_kwargs = aux_data
    return cls(ground_cost, solver, **solve_kwargs)