Source code for ott.tools.unreg

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple

import jax.numpy as jnp

from optax import assignment

from ott.geometry import costs, geometry, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import semidiscrete

# Names exported as this module's public API.
__all__ = ["hungarian", "wassdis_p"]


def hungarian(
    geom: geometry.Geometry
) -> Tuple[jnp.ndarray, semidiscrete.HardAssignmentOutput]:
  """Find an optimal one-to-one matching with the :term:`Hungarian algorithm`.

  The assignment itself is computed by
  ``optax.assignment.hungarian_algorithm`` applied to the cost matrix of
  ``geom``.

  Args:
    geom: Geometry whose
      :attr:`~ott.geometry.geometry.Geometry.cost_matrix` is square, i.e.
      of shape ``[n, n]``.

  Returns:
    A pair ``(value, output)``: ``value`` is the sum of the costs of the
    matched pairs divided by ``n``, and ``output`` is a
    :class:`~ott.solvers.linear.semidiscrete.HardAssignmentOutput` built on
    the linear problem defined by ``geom``, carrying the stacked row and
    column indices of the matching as ``paired_indices``.
  """
  n, m = geom.shape
  assert n == m, f"Hungarian can only match same # of points, got {n} and {m}."

  pairwise_cost = geom.cost_matrix
  rows, cols = assignment.hungarian_algorithm(pairwise_cost)

  # Average cost over the n matched (row, column) pairs.
  mean_matched_cost = pairwise_cost[rows, cols].sum() / n

  matching = semidiscrete.HardAssignmentOutput(
      linear_problem.LinearProblem(geom),
      paired_indices=jnp.stack([rows, cols]),
  )
  return mean_matched_cost, matching


def wassdis_p(x: jnp.ndarray, y: jnp.ndarray, *, p: float = 2.0) -> float:
  """Estimate the :term:`Wasserstein distance` between two point clouds.

  Builds a point cloud geometry on ``x`` and ``y`` with the
  :class:`~ott.geometry.costs.EuclideanP` cost of order ``p``, solves the
  resulting :term:`optimal matching problem` with :func:`hungarian`, and
  raises the matching value to the power ``1 / p``.

  Note:
    Both point clouds must have the same number of points, because
    :func:`hungarian` only handles square cost matrices.

  Args:
    x: ``[n, d]`` point cloud.
    y: ``[n, d]`` point cloud of the same size.
    p: order of the Wasserstein distance. Must be nonzero, since the
      matching value is raised to the power ``1 / p``.

  Returns:
    The `p`-Wasserstein distance between these point clouds.
  """
  ground_cost = costs.EuclideanP(p=p)
  cloud_geom = pointcloud.PointCloud(x, y, cost_fn=ground_cost)
  matching_value, _ = hungarian(cloud_geom)
  inverse_order = 1.0 / p
  return matching_value ** inverse_order