# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Any, Callable, Literal, Optional, Tuple, Union
import jax
import jax.numpy as jnp
from ott.geometry import costs, geometry, low_rank
from ott.math import utils as mu
__all__ = ["PointCloud"]
@jax.tree_util.register_pytree_node_class
class PointCloud(geometry.Geometry):
"""Defines geometry for 2 point clouds (possibly 1 vs itself).
Creates a geometry, specifying a cost function passed as CostFn type object.
When the number of points is large, setting the ``batch_size`` flag implies
that cost and kernel matrices used to update potentials or scalings
will be recomputed on the fly, rather than stored in memory. More precisely,
when setting ``batch_size``, the cost function will be partially cached by
storing norm values for each point in both point clouds, but the pairwise cost
function evaluations won't be.
Args:
    x: n x d array of n d-dimensional vectors.
    y: m x d array of m d-dimensional vectors. If ``None``, use ``x``.
    cost_fn: a CostFn function between two points in dimension d.
batch_size: When ``None``, the cost matrix corresponding to that point cloud
is computed, stored and later re-used at each application of
:meth:`apply_lse_kernel`. When ``batch_size`` is a positive integer,
computations are done in an online fashion, namely the cost matrix is
recomputed at each call of the :meth:`apply_lse_kernel` step,
``batch_size`` lines at a time, used on a vector and discarded.
The online computation is particularly useful for big point clouds
whose cost matrix does not fit in memory.
scale_cost: option to rescale the cost matrix. Implemented scalings are
'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'.
Alternatively, a float factor can be given to rescale the cost such
that ``cost_matrix /= scale_cost``.
kwargs: keyword arguments for :class:`~ott.geometry.geometry.Geometry`.
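
  Example:
    A minimal usage sketch (shapes and the ``batch_size`` value are
    illustrative)::

      import jax.numpy as jnp
      from ott.geometry import pointcloud

      x = jnp.ones((5_000, 3))
      y = jnp.zeros((8_000, 3))
      # cost blocks of 1_024 rows are recomputed on the fly, so the full
      # 5_000 x 8_000 cost matrix is never materialized
      geom = pointcloud.PointCloud(x, y, batch_size=1_024)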
"""
def __init__(
self,
x: jnp.ndarray,
y: Optional[jnp.ndarray] = None,
cost_fn: Optional[costs.CostFn] = None,
batch_size: Optional[int] = None,
scale_cost: Union[int, float, Literal["mean", "max_norm", "max_bound",
"max_cost", "median"]] = 1.0,
**kwargs: Any
):
super().__init__(**kwargs)
self.x = x
self.y = self.x if y is None else y
self.cost_fn = costs.SqEuclidean() if cost_fn is None else cost_fn
self._axis_norm = 0 if callable(self.cost_fn.norm) else None
if batch_size is not None:
assert batch_size > 0, f"`batch_size={batch_size}` must be positive."
self._batch_size = batch_size
self._scale_cost = scale_cost
@property
def _norm_x(self) -> Union[float, jnp.ndarray]:
if self._axis_norm == 0:
return self.cost_fn.norm(self.x)
return 0.0
@property
def _norm_y(self) -> Union[float, jnp.ndarray]:
if self._axis_norm == 0:
return self.cost_fn.norm(self.y)
return 0.0
@property
def can_LRC(self): # noqa: D102
return self.is_squared_euclidean and self._check_LRC_dim
@property
def _check_LRC_dim(self):
(n, m), d = self.shape, self.x.shape[1]
return n * m > (n + m) * d
@property
def cost_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102
if self.is_online:
return None
cost_matrix = self._compute_cost_matrix()
return cost_matrix * self.inv_scale_cost
@property
def kernel_matrix(self) -> Optional[jnp.ndarray]: # noqa: D102
if self.is_online:
return None
return jnp.exp(-self.cost_matrix / self.epsilon)
@property
def shape(self) -> Tuple[int, int]: # noqa: D102
# in the process of flattening/unflattening in vmap, `__init__`
# can be called with dummy objects
# we optionally access `shape` in order to get the batch size
if self.x is None or self.y is None:
return 0, 0
return self.x.shape[0], self.y.shape[0]
@property
def is_symmetric(self) -> bool: # noqa: D102
return self.y is None or (
jnp.all(self.x.shape == self.y.shape) and jnp.all(self.x == self.y)
)
@property
def is_squared_euclidean(self) -> bool: # noqa: D102
return isinstance(self.cost_fn, costs.SqEuclidean)
@property
def is_online(self) -> bool:
"""Whether the cost/kernel is computed on-the-fly."""
return self.batch_size is not None
@property
def cost_rank(self) -> int: # noqa: D102
return self.x.shape[1]
@property
def inv_scale_cost(self) -> float: # noqa: D102
if isinstance(self._scale_cost, (int, float, jax.Array)):
return 1.0 / self._scale_cost
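    # Rebind to the masked geometry so that the summary statistics used for
    # scaling ('max_cost', 'mean', 'median') ignore masked-out points.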
self = self._masked_geom()
if self._scale_cost == "max_cost":
if self.is_online:
return 1.0 / self._compute_summary_online(self._scale_cost)
return 1.0 / jnp.max(self._compute_cost_matrix())
if self._scale_cost == "mean":
if self.is_online:
return 1.0 / self._compute_summary_online(self._scale_cost)
if self.shape[0] > 0:
geom = self._masked_geom(mask_value=jnp.nan)._compute_cost_matrix()
return 1.0 / jnp.nanmean(geom)
return 1.0
if self._scale_cost == "median":
if not self.is_online:
geom = self._masked_geom(mask_value=jnp.nan)
return 1.0 / jnp.nanmedian(geom._compute_cost_matrix())
raise NotImplementedError(
"Using the median as scaling factor for "
"the cost matrix with the online mode is not implemented."
)
if self._scale_cost == "max_norm":
if self.cost_fn.norm is not None:
return 1.0 / jnp.maximum(self._norm_x.max(), self._norm_y.max())
return 1.0
if self._scale_cost == "max_bound":
if self.is_squared_euclidean:
x_argmax = jnp.argmax(self._norm_x)
y_argmax = jnp.argmax(self._norm_y)
max_bound = (
self._norm_x[x_argmax] + self._norm_y[y_argmax] +
2 * jnp.sqrt(self._norm_x[x_argmax] * self._norm_y[y_argmax])
)
return 1.0 / max_bound
raise NotImplementedError(
"Using max_bound as scaling factor for "
"the cost matrix when the cost is not squared euclidean "
"is not implemented."
)
raise ValueError(f"Scaling {self._scale_cost} not implemented.")
def _compute_cost_matrix(self) -> jnp.ndarray:
cost_matrix = self.cost_fn.all_pairs_pairwise(self.x, self.y)
if self._axis_norm is not None:
cost_matrix += self._norm_x[:, jnp.newaxis] + self._norm_y[jnp.newaxis, :]
return cost_matrix
def apply_lse_kernel( # noqa: D102
self,
f: jnp.ndarray,
g: jnp.ndarray,
eps: float,
vec: Optional[jnp.ndarray] = None,
axis: int = 0
) -> jnp.ndarray:
def body0(carry, i: int):
f, g, eps, vec = carry
y, g_ = self._leading_slice(self.y, i), self._leading_slice(g, i)
norm_y = self._norm_y if self._axis_norm is None else self._leading_slice(
self._norm_y, i
)
h_res, h_sgn = app(
self.x, y, self._norm_x, norm_y, f, g_, eps, vec, cost_fn,
self.inv_scale_cost
)
return carry, (h_res, h_sgn)
def body1(carry, i: int):
f, g, eps, vec = carry
x, f_ = self._leading_slice(self.x, i), self._leading_slice(f, i)
norm_x = self._norm_x if self._axis_norm is None else self._leading_slice(
self._norm_x, i
)
h_res, h_sgn = app(
self.y, x, self._norm_y, norm_x, g, f_, eps, vec, cost_fn,
self.inv_scale_cost
)
return carry, (h_res, h_sgn)
def rest(i: int):
if axis == 0:
norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:]
return app(
self.x, self.y[i:], self._norm_x, norm_y, f, g[i:], eps, vec,
cost_fn, self.inv_scale_cost
)
norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:]
return app(
self.y, self.x[i:], self._norm_y, norm_x, g, f[i:], eps, vec, cost_fn,
self.inv_scale_cost
)
if not self.is_online:
return super().apply_lse_kernel(f, g, eps, vec, axis)
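    # Online path: full batches of size ``batch_size`` are reduced with
    # `jax.lax.scan` below; `rest` then handles the remaining rows that do
    # not fill a complete batch.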
app = jax.vmap(
_apply_lse_kernel_xy,
in_axes=[
None, 0, None, self._axis_norm, None, 0, None, None, None, None
]
)
if axis == 0:
fun, cost_fn = body0, self.cost_fn.pairwise
v, n = g, self._y_nsplit
elif axis == 1:
fun, cost_fn = body1, lambda y, x: self.cost_fn.pairwise(x, y)
v, n = f, self._x_nsplit
else:
raise ValueError(axis)
_, (h_res, h_sign) = jax.lax.scan(
fun, init=(f, g, eps, vec), xs=jnp.arange(n)
)
h_res, h_sign = jnp.concatenate(h_res), jnp.concatenate(h_sign)
h_res_rest, h_sign_rest = rest(n * self.batch_size)
h_res = jnp.concatenate([h_res, h_res_rest])
h_sign = jnp.concatenate([h_sign, h_sign_rest])
return eps * h_res - jnp.where(jnp.isfinite(v), v, 0), h_sign
def apply_kernel( # noqa: D102
self,
scaling: jnp.ndarray,
eps: Optional[float] = None,
axis: int = 0
) -> jnp.ndarray:
if eps is None:
eps = self.epsilon
if not self.is_online:
return super().apply_kernel(scaling, eps, axis)
# TODO(michalk8): batch this properly
app = jax.vmap(
_apply_kernel_xy,
in_axes=[None, 0, None, self._axis_norm, None, None, None, None]
)
if axis == 0:
return app(
self.x, self.y, self._norm_x, self._norm_y, scaling, eps,
self.cost_fn.pairwise, self.inv_scale_cost
)
# for non-symmetric costs
cost_fn = lambda y, x: self.cost_fn.pairwise(x, y)
return app(
self.y, self.x, self._norm_y, self._norm_x, scaling, eps, cost_fn,
self.inv_scale_cost
)
def transport_from_potentials( # noqa: D102
self, f: jnp.ndarray, g: jnp.ndarray
) -> jnp.ndarray:
if not self.is_online:
return super().transport_from_potentials(f, g)
in_axes = [None, 0, None, self._axis_norm, None, 0, None, None, None]
transport = jax.vmap(_transport_from_potentials_xy, in_axes=in_axes)
cost_fn = lambda y, x: self.cost_fn.pairwise(x, y)
return transport(
self.y, self.x, self._norm_y, self._norm_x, g, f, self.epsilon, cost_fn,
self.inv_scale_cost
)
def transport_from_scalings( # noqa: D102
self, u: jnp.ndarray, v: jnp.ndarray
) -> jnp.ndarray:
if not self.is_online:
return super().transport_from_scalings(u, v)
in_axes = [None, 0, None, self._axis_norm, None, 0, None, None, None]
transport = jax.vmap(_transport_from_scalings_xy, in_axes=in_axes)
cost_fn = lambda y, x: self.cost_fn.pairwise(x, y)
return transport(
self.y, self.x, self._norm_y, self._norm_x, v, u, self.epsilon, cost_fn,
self.inv_scale_cost
)
def apply_cost(
self,
arr: jnp.ndarray,
axis: int = 0,
fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
is_linear: bool = False,
) -> jnp.ndarray:
"""Apply cost matrix to array (vector or matrix).
This function applies the geometry's cost matrix, to perform either
output = C arr (if axis=1)
output = C' arr (if axis=0)
where C is [num_a, num_b] matrix resulting from the (optional) elementwise
application of fn to each entry of the :attr:`cost_matrix`.
Args:
arr: jnp.ndarray [num_a or num_b, batch], vector that will be multiplied
by the cost matrix.
axis: standard cost matrix if axis=1, transpose if 0.
fn: function optionally applied to cost matrix element-wise, before the
apply.
is_linear: Whether ``fn`` is a linear function.
If true and :attr:`is_squared_euclidean` is ``True``, efficient
implementation is used. See :func:`ott.geometry.geometry.is_linear`
for a heuristic to help determine if a function is linear.
Returns:
A jnp.ndarray, [num_b, batch] if axis=0 or [num_a, batch] if axis=1
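
    Example:
      A minimal sketch; shapes are illustrative and assume an existing
      ``geom = PointCloud(x, y)`` with ``x: [num_a, d]`` and
      ``y: [num_b, d]``::

        vec = jnp.ones((geom.shape[0],))
        col_sums = geom.apply_cost(vec, axis=0)
        # shape [num_b, 1]; entry j holds sum_i C[i, j]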
"""
    # Switch to the efficient computation for the squared Euclidean case.
if self.is_squared_euclidean and (fn is None or is_linear):
return self.vec_apply_cost(arr, axis, fn=fn)
return self._apply_cost(arr, axis, fn=fn)
def _apply_cost(
self, arr: jnp.ndarray, axis: int = 0, fn=None
) -> jnp.ndarray:
"""See :meth:`apply_cost`."""
if not self.is_online:
return super().apply_cost(arr, axis, fn)
# TODO(michalk8): batch this properly
app = jax.vmap(
_apply_cost_xy,
in_axes=[None, 0, None, self._axis_norm, None, None, None, None]
)
if arr.ndim == 1:
arr = arr.reshape(-1, 1)
if axis == 0:
return app(
self.x, self.y, self._norm_x, self._norm_y, arr,
self.cost_fn.pairwise, self.inv_scale_cost, fn
)
cost_fn = lambda y, x: self.cost_fn.pairwise(x, y)
return app(
self.y, self.x, self._norm_y, self._norm_x, arr, cost_fn,
self.inv_scale_cost, fn
)
def vec_apply_cost(
self,
arr: jnp.ndarray,
axis: int = 0,
fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None
) -> jnp.ndarray:
"""Apply the geometry's cost matrix in a vectorized way.
This function can be used when the cost matrix is squared euclidean
and ``fn`` is a linear function.
Args:
arr: jnp.ndarray [num_a or num_b, p], vector that will be multiplied
by the cost matrix.
axis: standard cost matrix if axis=1, transport if 0.
fn: function optionally applied to cost matrix element-wise, before the
application.
Returns:
A jnp.ndarray, [num_b, p] if axis=0 or [num_a, p] if axis=1
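
    Example:
      For squared Euclidean costs, :math:`C = n_x 1^T + 1 n_y^T - 2 x y^T`,
      so :math:`C` can be applied without being materialized; a sketch with
      illustrative shapes::

        geom = PointCloud(x, y)                 # x: [num_a, d], y: [num_b, d]
        out = geom.vec_apply_cost(arr, axis=0)  # arr: [num_a, p] -> [num_b, p]
        # same result as geom.cost_matrix.T @ arr, but with
        # O((num_a + num_b) d) memory instead of O(num_a num_b)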
"""
    assert self.is_squared_euclidean, "Cost matrix is not squared Euclidean."
rank = arr.ndim
x, y = (self.x, self.y) if axis == 0 else (self.y, self.x)
nx, ny = jnp.asarray(self._norm_x), jnp.asarray(self._norm_y)
nx, ny = (nx, ny) if axis == 0 else (ny, nx)
applied_cost = jnp.dot(nx, arr).reshape(1, -1)
applied_cost += ny.reshape(-1, 1) * jnp.sum(arr, axis=0).reshape(1, -1)
cross_term = -2.0 * jnp.dot(y, jnp.dot(x.T, arr))
applied_cost += cross_term[:, None] if rank == 1 else cross_term
if fn is not None:
applied_cost = fn(applied_cost)
return self.inv_scale_cost * applied_cost
def _leading_slice(self, t: jnp.ndarray, i: int) -> jnp.ndarray:
start_indices = [i * self.batch_size] + (t.ndim - 1) * [0]
slice_sizes = [self.batch_size] + list(t.shape[1:])
return jax.lax.dynamic_slice(t, start_indices, slice_sizes)
def _compute_summary_online(
self, summary: Literal["mean", "max_cost"]
) -> float:
"""Compute mean or max of cost matrix online, i.e. without instantiating it.
Args:
summary: can be 'mean' or 'max_cost'.
Returns:
summary statistics
"""
scale_cost = 1.0
def body0(vec: jnp.ndarray, i: int):
y = self._leading_slice(self.y, i)
norm_y = self._norm_y if self._axis_norm is None else self._leading_slice(
self._norm_y, i
)
h_res = app(self.x, y, self._norm_x, norm_y, vec, cost_fn, scale_cost)
return vec, h_res
def body1(vec: jnp.ndarray, i: int):
x = self._leading_slice(self.x, i)
norm_x = self._norm_x if self._axis_norm is None else self._leading_slice(
self._norm_x, i
)
h_res = app(self.y, x, self._norm_y, norm_x, vec, cost_fn, scale_cost)
return vec, h_res
def rest(i: int) -> jnp.ndarray:
if batch_for_y:
norm_y = self._norm_y if self._axis_norm is None else self._norm_y[i:]
return app(
self.x, self.y[i:], self._norm_x, norm_y, vec, cost_fn, scale_cost
)
norm_x = self._norm_x if self._axis_norm is None else self._norm_x[i:]
return app(
self.y, self.x[i:], self._norm_y, norm_x, vec, cost_fn, scale_cost
)
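    # As in `apply_lse_kernel`, full batches are reduced with `jax.lax.scan`
    # below, and `rest` covers the remainder that does not fill a complete
    # batch.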
if summary == "mean":
fn = _apply_cost_xy
elif summary == "max_cost":
fn = _apply_max_xy
else:
raise ValueError(
f"Scaling method {summary} does not exist for online mode."
)
app = jax.vmap(
fn, in_axes=[None, 0, None, self._axis_norm, None, None, None]
)
batch_for_y = self.shape[0] < self.shape[1]
if batch_for_y:
fun, cost_fn = body0, self.cost_fn.pairwise
n = self._y_nsplit
vec, other = self._n_normed_ones, self._m_normed_ones
else:
fun, cost_fn = body1, lambda y, x: self.cost_fn.pairwise(x, y)
n = self._x_nsplit
vec, other = self._m_normed_ones, self._n_normed_ones
_, val = jax.lax.scan(fun, init=vec, xs=jnp.arange(n))
val = jnp.concatenate(val).squeeze()
val_rest = rest(n * self.batch_size)
val_res = jnp.concatenate([val, val_rest])
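    # `val_res` holds, per point, either the (unscaled) cost applied to a
    # vector of normalized ones ('mean') or the row-wise maximum absolute
    # cost ('max_cost').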
if summary == "mean":
return jnp.sum(val_res * other)
if summary == "max_cost":
return jnp.max(val_res)
raise ValueError(
f"Scaling method {summary} does not exist for online mode."
)
def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray:
"""Compute barycenter of points in self.x using weights."""
return self.cost_fn.barycenter(self.x, weights)[0]
@classmethod
def prepare_divergences(
cls,
x: jnp.ndarray,
y: jnp.ndarray,
static_b: bool = False,
src_mask: Optional[jnp.ndarray] = None,
tgt_mask: Optional[jnp.ndarray] = None,
**kwargs: Any
) -> Tuple["PointCloud", ...]:
"""Instantiate the geometries used for a divergence computation."""
couples = [(x, y), (x, x)]
masks = [(src_mask, tgt_mask), (src_mask, src_mask)]
if not static_b:
couples += [(y, y)]
masks += [(tgt_mask, tgt_mask)]
return tuple(
cls(x, y, src_mask=x_mask, tgt_mask=y_mask, **kwargs)
for ((x, y), (x_mask, y_mask)) in zip(couples, masks)
)
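
  # Note: with the default ``static_b=False``, three geometries are built,
  # as used for a Sinkhorn divergence,
  # :math:`OT(x, y) - (OT(x, x) + OT(y, y)) / 2`; a hypothetical call:
  #
  #   geom_xy, geom_xx, geom_yy = PointCloud.prepare_divergences(x, y)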
def tree_flatten(self): # noqa: D102
return (
self.x,
self.y,
self._src_mask,
self._tgt_mask,
self._epsilon_init,
self.cost_fn,
), {
"batch_size": self._batch_size,
"scale_cost": self._scale_cost
}
@classmethod
def tree_unflatten(cls, aux_data, children): # noqa: D102
x, y, src_mask, tgt_mask, epsilon, cost_fn = children
return cls(
x,
y,
cost_fn=cost_fn,
src_mask=src_mask,
tgt_mask=tgt_mask,
epsilon=epsilon,
**aux_data
)
def _cosine_to_sqeucl(self) -> "PointCloud":
assert isinstance(self.cost_fn, costs.Cosine), type(self.cost_fn)
(x, y, *args, _), aux_data = self.tree_flatten()
x = x / jnp.linalg.norm(x, axis=-1, keepdims=True)
y = y / jnp.linalg.norm(y, axis=-1, keepdims=True)
# TODO(michalk8): find a better way
aux_data["scale_cost"] = 2.0 / self.inv_scale_cost
cost_fn = costs.SqEuclidean()
return type(self).tree_unflatten(aux_data, [x, y] + args + [cost_fn])
def to_LRCGeometry(
self,
scale: float = 1.0,
**kwargs: Any,
) -> Union[low_rank.LRCGeometry, "PointCloud"]:
r"""Convert point cloud to low-rank geometry.
Args:
scale: Value used to rescale the factors of the low-rank geometry.
Useful when this geometry is used in the linear term of fused GW.
kwargs: Keyword arguments, such as ``rank``, to
:meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry` used when
the point cloud does not have squared Euclidean cost.
Returns:
Returns the unmodified point cloud if :math:`n m \ge (n + m) d`, where
:math:`n, m` is the shape and :math:`d` is the dimension of the point
cloud with squared Euclidean cost.
Otherwise, returns the re-scaled low-rank geometry.
"""
if self.is_squared_euclidean:
if self._check_LRC_dim:
return self._sqeucl_to_lr(scale)
# we don't update the `scale_factor` because in GW, the linear cost
# is first materialized and then scaled by `fused_penalty` afterwards
return self
return super().to_LRCGeometry(scale=scale, **kwargs)
def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry:
assert self.is_squared_euclidean, "Geometry must be squared Euclidean."
n, m = self.shape
nx = jnp.sum(self.x ** 2, axis=1, keepdims=True)
ny = jnp.sum(self.y ** 2, axis=1, keepdims=True)
cost_1 = jnp.concatenate((nx, jnp.ones((n, 1)), -jnp.sqrt(2.0) * self.x),
axis=1)
cost_2 = jnp.concatenate((jnp.ones((m, 1)), ny, jnp.sqrt(2.0) * self.y),
axis=1)
return low_rank.LRCGeometry(
cost_1=cost_1,
cost_2=cost_2,
scale_factor=scale,
epsilon=self._epsilon_init,
relative_epsilon=self._relative_epsilon,
scale_cost=self._scale_cost,
src_mask=self.src_mask,
tgt_mask=self.tgt_mask,
)
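
  # Factorization behind `_sqeucl_to_lr`:
  #   C[i, j] = |x_i|^2 + |y_j|^2 - 2 <x_i, y_j> = <cost_1[i], cost_2[j]>,
  # with cost_1[i] = (|x_i|^2, 1, -sqrt(2) x_i) and
  # cost_2[j] = (1, |y_j|^2, sqrt(2) y_j), so `cost_1 @ cost_2.T` recovers
  # the full cost matrix from factors of width d + 2.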
def subset( # noqa: D102
self, src_ixs: Optional[jnp.ndarray], tgt_ixs: Optional[jnp.ndarray],
**kwargs: Any
) -> "PointCloud":
def subset_fn(
arr: Optional[jnp.ndarray],
ixs: Optional[jnp.ndarray],
    ) -> Optional[jnp.ndarray]:
return arr if arr is None or ixs is None else arr[ixs, ...]
return self._mask_subset_helper(
src_ixs, tgt_ixs, fn=subset_fn, propagate_mask=True, **kwargs
)
def mask( # noqa: D102
self,
src_mask: Optional[jnp.ndarray],
tgt_mask: Optional[jnp.ndarray],
mask_value: float = 0.0,
) -> "PointCloud":
def mask_fn(
arr: Optional[jnp.ndarray],
mask: Optional[jnp.ndarray],
) -> Optional[jnp.ndarray]:
if arr is None or mask is None:
return arr
return jnp.where(mask[:, None], arr, mask_value)
src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
return self._mask_subset_helper(
src_mask, tgt_mask, fn=mask_fn, propagate_mask=False
)
def _mask_subset_helper(
self,
src_ixs: Optional[jnp.ndarray],
tgt_ixs: Optional[jnp.ndarray],
*,
fn: Callable[[Optional[jnp.ndarray], Optional[jnp.ndarray]],
Optional[jnp.ndarray]],
propagate_mask: bool,
**kwargs: Any,
) -> "PointCloud":
(x, y, src_mask, tgt_mask, *children), aux_data = self.tree_flatten()
x = fn(x, src_ixs)
y = fn(y, tgt_ixs)
if propagate_mask:
src_mask = self._normalize_mask(src_mask, self.shape[0])
tgt_mask = self._normalize_mask(tgt_mask, self.shape[1])
src_mask = fn(src_mask, src_ixs)
tgt_mask = fn(tgt_mask, tgt_ixs)
aux_data = {**aux_data, **kwargs}
return type(self).tree_unflatten(
aux_data, [x, y, src_mask, tgt_mask] + children
)
@property
def dtype(self) -> jnp.dtype: # noqa: D102
return self.x.dtype
@property
def batch_size(self) -> Optional[int]:
"""Batch size for online mode."""
if self._batch_size is None:
return None
n, m = self.shape
return min(n, m, self._batch_size)
@property
def _x_nsplit(self) -> Optional[int]:
if self.batch_size is None:
return None
n, _ = self.shape
return int(math.floor(n / self.batch_size))
@property
def _y_nsplit(self) -> Optional[int]:
if self.batch_size is None:
return None
_, m = self.shape
return int(math.floor(m / self.batch_size))
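
  # Normalized ones vectors used by `_compute_summary_online`; a minimal
  # sketch, assuming uniform weights renormalized by the optional
  # source/target masks inherited from the base `Geometry`.
  @property
  def _n_normed_ones(self) -> jnp.ndarray:
    """Normalized array of shape ``[num_a,]``."""
    mask = self.src_mask
    arr = jnp.ones(self.shape[0]) if mask is None else mask
    return arr / jnp.sum(arr)

  @property
  def _m_normed_ones(self) -> jnp.ndarray:
    """Normalized array of shape ``[num_b,]``."""
    mask = self.tgt_mask
    arr = jnp.ones(self.shape[1]) if mask is None else mask
    return arr / jnp.sum(arr)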
def _apply_lse_kernel_xy(
x, y, norm_x, norm_y, f, g, eps, vec, cost_fn, scale_cost
):
c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
return mu.logsumexp((f + g - c) / eps, b=vec, return_sign=True, axis=-1)
def _transport_from_potentials_xy(
x, y, norm_x, norm_y, f, g, eps, cost_fn, scale_cost
):
c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
return jnp.exp((f + g - c) / eps)
def _apply_kernel_xy(x, y, norm_x, norm_y, vec, eps, cost_fn, scale_cost):
c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
return jnp.dot(jnp.exp(-c / eps), vec)
def _transport_from_scalings_xy(
x, y, norm_x, norm_y, u, v, eps, cost_fn, scale_cost
):
c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
  # NB: `_cost` already applies `scale_cost`, so no extra factor is needed.
  return jnp.exp(-c / eps) * u * v
def _cost(x, y, norm_x, norm_y, cost_fn, scale_cost):
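  """Scaled cost between all points in `x` and a single point `y`."""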
one_line_pairwise = jax.vmap(cost_fn, in_axes=[0, None])
cost = norm_x + norm_y + one_line_pairwise(x, y)
return cost * scale_cost
def _apply_cost_xy(x, y, norm_x, norm_y, vec, cost_fn, scale_cost, fn=None):
"""Apply [num_b, num_a] fn(cost) matrix (or transpose) to vector.
Applies [num_b, num_a] ([num_a, num_b] if axis=1 from `apply_cost`)
fn(cost) matrix (or transpose) to vector.
Args:
x: jnp.ndarray [num_a, d], first pointcloud
y: jnp.ndarray [num_b, d], second pointcloud
norm_x: jnp.ndarray [num_a,], (squared) norm as defined in by cost_fn
norm_y: jnp.ndarray [num_b,], (squared) norm as defined in by cost_fn
vec: jnp.ndarray [num_a,] ([num_b,] if axis=1 from `apply_cost`) vector
cost_fn: a CostFn function between two points in dimension d.
scale_cost: scaling factor of the cost matrix.
fn: function optionally applied to cost matrix element-wise, before the
apply.
Returns:
A jnp.ndarray corresponding to cost x vector
"""
c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
return jnp.dot(c, vec) if fn is None else jnp.dot(fn(c), vec)
def _apply_max_xy(x, y, norm_x, norm_y, vec, cost_fn, scale_cost):
del vec
c = _cost(x, y, norm_x, norm_y, cost_fn, scale_cost)
return jnp.max(jnp.abs(c))