# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Literal, Optional, Tuple, Union
import jax.numpy as jnp
import jax.tree_util as jtu
from ott import utils
from ott.geometry import costs, geometry, low_rank
from ott.math import utils as mu
__all__ = ["PointCloud"]
[docs]
@jtu.register_pytree_node_class
class PointCloud(geometry.Geometry):
    """Defines geometry for 2 point clouds (possibly 1 vs itself).

    When the number of points is large, setting the :attr:`batch_size` flag
    implies that cost and kernel matrices used to update potentials or
    scalings will be recomputed on the fly, rather than stored in memory.

    Args:
      x: Array of shape ``[n, d]``.
      y: Array of shape ``[m, d]``. If :obj:`None`, use ``x``.
      cost_fn: Cost function between two points in dimension :math:`d`.
        If :obj:`None`, the squared Euclidean cost is used.
      batch_size: If :obj:`None`, the cost matrix corresponding to that
        point cloud is computed, stored and later re-used at each
        application of :meth:`apply_lse_kernel`. When ``batch_size`` is a
        positive integer, computations are done in an online fashion,
        namely the cost matrix is recomputed at each call of the
        :meth:`apply_lse_kernel` step, ``batch_size`` lines at a time, used
        on a vector and discarded. The online computation is particularly
        useful for big point clouds whose cost matrix does not fit in
        memory.
      scale_cost: Option to rescale the cost matrix. Implemented scalings
        are 'median', 'mean', 'max_cost', 'max_norm' and 'max_bound'.
        Alternatively, a float factor can be given to rescale the cost such
        that ``cost_matrix /= scale_cost``.
      kwargs: Keyword arguments for
        :class:`~ott.geometry.geometry.Geometry`.
    """
def __init__(
    self,
    x: jnp.ndarray,
    y: Optional[jnp.ndarray] = None,
    cost_fn: Optional[costs.CostFn] = None,
    batch_size: Optional[int] = None,
    scale_cost: Union[float, Literal["mean", "max_norm", "max_bound",
                                     "max_cost", "median"]] = 1.0,
    **kwargs: Any,
):
    """Store the two point clouds, the cost and the scaling options.

    Args:
      x: First point cloud.
      y: Second point cloud; ``x`` is re-used when :obj:`None`.
      cost_fn: Ground cost; a fresh :class:`costs.SqEuclidean` is created
        when :obj:`None`.
      batch_size: Optional positive batch size; stored as ``_batch_size``.
      scale_cost: Scaling option or factor; stored as ``_scale_cost``.
      kwargs: Forwarded unchanged to the parent constructor.
    """
    super().__init__(**kwargs)

    self.x = x
    # Fall back to the very same array so the cloud is compared to itself.
    self.y = x if y is None else y

    if cost_fn is None:
        cost_fn = costs.SqEuclidean()
    self.cost_fn = cost_fn

    if batch_size is not None:
        assert batch_size > 0, f"`batch_size={batch_size}` must be positive."

    self._batch_size = batch_size
    self._scale_cost = scale_cost
[docs]
def apply_lse_kernel(
    self,
    f: jnp.ndarray,
    g: jnp.ndarray,
    eps: float,
    vec: Optional[jnp.ndarray] = None,
    axis: int = 0
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Apply the kernel in log-space, optionally in an online fashion.

    When the geometry is not online the call is forwarded to the parent
    class. Otherwise the cost is recomputed slice by slice through
    :func:`utils.batched_vmap` instead of being read from a stored matrix.

    Args:
      f: First potential.
      g: Second potential.
      eps: Value dividing ``f + g - cost`` inside the log-sum-exp.
      vec: Optional weights passed as ``b`` to the log-sum-exp.
      axis: ``0`` maps over ``y``/``g``, any other value over ``x``/``f``.

    Returns:
      The log-sum-exp values multiplied by ``eps``, from which the finite
      entries of ``g`` (``axis != 1``) or ``f`` (``axis == 1``) have been
      subtracted, together with the signs returned by the log-sum-exp.
    """
    if not self.is_online:
        return super().apply_lse_kernel(f, g, eps, vec, axis)

    inv_scale = self.inv_scale_cost

    def lse_slice(
        pts_x: jnp.ndarray, pts_y: jnp.ndarray, pot_f: jnp.ndarray,
        pot_g: jnp.ndarray
    ) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # The mapped argument arrives as a single point; make it a 1-row array.
        pts_x = jnp.atleast_2d(pts_x)
        pts_y = jnp.atleast_2d(pts_y)
        scaled = self.cost_fn.all_pairs(pts_x, pts_y) * inv_scale
        # Drop the singleton dimension created for the mapped point.
        scaled = scaled.squeeze(1 - axis)
        lse, sign = mu.logsumexp(
            (pot_f + pot_g - scaled) / eps, b=vec, return_sign=True
        )
        return eps * lse, sign

    if axis == 0:
        mapped_axes = (None, 0, None, 0)
    else:
        mapped_axes = (0, None, 0, None)

    batched = utils.batched_vmap(
        lse_slice,
        batch_size=self.batch_size,
        in_axes=mapped_axes,
    )
    values, signs = batched(self.x, self.y, f, g)

    # Remove the potential along the mapped side, ignoring non-finite entries.
    potential = f if axis == 1 else g
    correction = jnp.where(jnp.isfinite(potential), potential, 0)
    return values - correction, signs
[docs]
def apply_kernel(
    self,
    vec: jnp.ndarray,
    eps: Optional[float] = None,
    axis: int = 0
) -> jnp.ndarray:
    """Apply the kernel ``exp(-cost / eps)`` to a vector.

    When the geometry is not online the call is forwarded to the parent
    class; otherwise the cost is recomputed slice by slice through
    :func:`utils.batched_vmap`.

    Args:
      vec: Vector the kernel is applied to.
      eps: Regularization; ``self.epsilon`` is used when :obj:`None`.
      axis: ``0`` maps over ``y``, any other value maps over ``x``.

    Returns:
      The result of the batched kernel application.
    """
    eps = self.epsilon if eps is None else eps
    if not self.is_online:
        return super().apply_kernel(vec, eps, axis)

    inv_scale = self.inv_scale_cost

    def kernel_slice(
        pts_x: jnp.ndarray, pts_y: jnp.ndarray, arr: jnp.ndarray
    ) -> jnp.ndarray:
        pts_x = jnp.atleast_2d(pts_x)
        pts_y = jnp.atleast_2d(pts_y)
        scaled = self.cost_fn.all_pairs(pts_x, pts_y) * inv_scale
        # Drop the singleton dimension created for the mapped point.
        scaled = scaled.squeeze(1 - axis)
        return jnp.dot(jnp.exp(-scaled / eps), arr)

    if axis == 0:
        mapped_axes = (None, 0, None)
    else:
        mapped_axes = (0, None, None)

    batched = utils.batched_vmap(
        kernel_slice, batch_size=self.batch_size, in_axes=mapped_axes
    )
    return batched(self.x, self.y, vec)
def _apply_cost_to_vec(
    self,
    vec: jnp.ndarray,
    axis: int = 0,
    fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
    is_linear: bool = False,
    scale_cost: Optional[float] = None,
) -> jnp.ndarray:
    """Apply the (optionally transformed) cost to a vector.

    Three paths exist, tried in this order:

    1. squared Euclidean cost with ``fn`` absent or flagged linear:
       delegate to :meth:`_apply_sqeucl_cost`;
    2. geometry not online: delegate to the parent class;
    3. otherwise recompute the cost slice by slice through
       :func:`utils.batched_vmap`.

    Args:
      vec: Vector the cost is applied to.
      axis: ``0`` maps over ``y``, any other value maps over ``x``.
      fn: Optional function applied to the cost.
      is_linear: Whether ``fn`` is linear.
      scale_cost: Multiplier for the cost; ``self.inv_scale_cost`` is used
        when :obj:`None`.

    Returns:
      The cost applied to ``vec``.
    """
    # when computing the online properties, this is set to 1.0
    inv_scale = self.inv_scale_cost if scale_cost is None else scale_cost

    # switch to an efficient computation for the squared Euclidean case
    if self.is_squared_euclidean and (fn is None or is_linear):
        return self._apply_sqeucl_cost(vec, inv_scale, axis=axis, fn=fn)

    # materialize the cost
    if not self.is_online:
        return super()._apply_cost_to_vec(
            vec, axis=axis, fn=fn, is_linear=is_linear
        )

    def cost_slice(
        pts_x: jnp.ndarray, pts_y: jnp.ndarray, arr: jnp.ndarray
    ) -> jnp.ndarray:
        pts_x = jnp.atleast_2d(pts_x)
        pts_y = jnp.atleast_2d(pts_y)
        scaled = self.cost_fn.all_pairs(pts_x, pts_y) * inv_scale
        scaled = scaled.squeeze(1 - axis)
        if fn is not None:
            scaled = fn(scaled)
        return jnp.dot(scaled, arr)

    if axis == 0:
        mapped_axes = (None, 0, None)
    else:
        mapped_axes = (0, None, None)

    batched = utils.batched_vmap(
        cost_slice, batch_size=self.batch_size, in_axes=mapped_axes
    )
    return batched(self.x, self.y, vec)
def _apply_sqeucl_cost(
    self,
    vec: jnp.ndarray,
    scale_cost: float,
    axis: int = 0,
    fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
) -> jnp.ndarray:
    """Apply the squared Euclidean cost to a vector without building it.

    The product is assembled from the norms of both clouds and one
    cross term, so no ``[n, m]`` matrix is formed.

    Args:
      vec: One-dimensional vector the cost is applied to.
      scale_cost: Factor multiplying the final result.
      axis: ``0`` contracts over ``x``, any other value over ``y``.
      fn: Optional function applied to the result before scaling.

    Returns:
      The scaled cost applied to ``vec``.
    """
    assert vec.ndim == 1, vec.shape
    assert self.is_squared_euclidean, "Cost matrix is not a squared Euclidean."

    if axis == 0:
        src, dst = self.x, self.y
    else:
        src, dst = self.y, self.x

    src_norms = self.cost_fn.norm(src)
    dst_norms = self.cost_fn.norm(dst)

    # Norm terms first, then the cross term, in the original operation order.
    out = jnp.dot(src_norms, vec) + dst_norms * jnp.sum(vec, axis=0)
    cross = jnp.dot(dst, jnp.dot(src.T, vec))
    out = out - 2.0 * cross

    if fn is not None:
        out = fn(out)
    return scale_cost * out
def _compute_summary_online(
    self, summary: Literal["mean", "max_cost"]
) -> jnp.ndarray:
    """Compute mean or max of cost matrix online, i.e. without storing it.

    Args:
      summary: Either ``'mean'`` or ``'max_cost'``.

    Returns:
      The requested summary statistic of the unscaled cost.

    Raises:
      ValueError: If ``summary`` is neither ``'mean'`` nor ``'max_cost'``.
    """
    if summary == "mean":
        num_x, num_y = self.shape
        # Uniform weights on both sides turn the double sum into a mean.
        row_weights = jnp.full((num_x,), fill_value=1.0 / num_x)
        col_weights = jnp.full((num_y,), fill_value=1.0 / num_y)
        applied = self._apply_cost_to_vec(row_weights, scale_cost=1.0)
        return jnp.sum(applied * col_weights)

    if summary != "max_cost":
        raise ValueError(
            f"Scaling method {summary} does not exist for online mode."
        )

    def max_abs_cost(pts_x: jnp.ndarray, pts_y: jnp.ndarray) -> jnp.ndarray:
        pts_x = jnp.atleast_2d(pts_x)
        pts_y = jnp.atleast_2d(pts_y)
        pairwise = self.cost_fn.all_pairs(pts_x, pts_y)
        return jnp.max(jnp.abs(pairwise))

    batched = utils.batched_vmap(
        max_abs_cost, batch_size=self.batch_size, in_axes=[0, None]
    )
    return jnp.max(batched(self.x, self.y))
[docs]
def barycenter(self, weights: jnp.ndarray) -> jnp.ndarray:
    """Return the barycenter of the points in ``self.x`` under ``weights``.

    The computation is delegated to ``self.cost_fn.barycenter``; only the
    first element of what it returns is kept.
    """
    outputs = self.cost_fn.barycenter(self.x, weights)
    return outputs[0]
[docs]
@classmethod
def prepare_divergences(
    cls,
    x: jnp.ndarray,
    y: jnp.ndarray,
    static_b: bool = False,
    **kwargs: Any
) -> Tuple["PointCloud", ...]:
    """Instantiate the geometries used for a divergence computation.

    Args:
      x: First point cloud.
      y: Second point cloud.
      static_b: When true, the ``(y, y)`` geometry is not created.
      kwargs: Keyword arguments passed to every constructor call.

    Returns:
      Geometries for ``(x, y)``, ``(x, x)`` and, unless ``static_b`` is
      set, ``(y, y)``, in that order.
    """
    pairs = [(x, y), (x, x)]
    if not static_b:
        pairs.append((y, y))
    return tuple(cls(first, second, **kwargs) for first, second in pairs)
def tree_flatten(self):
    """Split the geometry into pytree children and auxiliary data.

    Returns:
      A pair ``(children, aux_data)``: ``children`` holds ``x``, ``y``,
      the initial epsilon and the cost function; ``aux_data`` holds the
      batch size, the cost scaling and the relative-epsilon setting.
    """
    children = (self.x, self.y, self._epsilon_init, self.cost_fn)
    aux_data = {
        "batch_size": self._batch_size,
        "scale_cost": self._scale_cost,
        "relative_epsilon": self._relative_epsilon,
    }
    return children, aux_data
@classmethod
def tree_unflatten(cls, aux_data, children):
    """Rebuild a geometry from the output of :meth:`tree_flatten`.

    Args:
      aux_data: Mapping of keyword arguments for the constructor.
      children: Sequence ``(x, y, epsilon, cost_fn)``.

    Returns:
      A new instance of ``cls``.
    """
    points_x, points_y, eps_init, cost = children
    return cls(
        points_x, points_y, cost_fn=cost, epsilon=eps_init, **aux_data
    )
def _cosine_to_sqeucl(self) -> "PointCloud":
    """Re-express a cosine point cloud with a squared Euclidean cost.

    Both clouds are divided by their row norms and ``scale_cost`` is set
    to ``2.0 / self.inv_scale_cost`` in the rebuilt geometry.

    Returns:
      A new geometry of the same type using :class:`costs.SqEuclidean`.
    """
    assert isinstance(self.cost_fn, costs.Cosine), type(self.cost_fn)

    children, aux_data = self.tree_flatten()
    # The last child is the cosine cost, which is being replaced.
    points_x, points_y, *middle, _ = children

    unit_x = points_x / jnp.linalg.norm(points_x, axis=-1, keepdims=True)
    unit_y = points_y / jnp.linalg.norm(points_y, axis=-1, keepdims=True)

    # TODO(michalk8): find a better way
    aux_data["scale_cost"] = 2.0 / self.inv_scale_cost

    new_children = [unit_x, unit_y, *middle, costs.SqEuclidean()]
    return type(self).tree_unflatten(aux_data, new_children)
[docs]
def to_LRCGeometry(
    self,
    scale: float = 1.0,
    **kwargs: Any,
) -> Union[low_rank.LRCGeometry, "PointCloud"]:
    r"""Convert point cloud to low-rank geometry.

    Args:
      scale: Value used to rescale the factors of the low-rank geometry.
        Useful when this geometry is used in the linear term of fused GW.
      kwargs: Keyword arguments, such as ``rank``, to
        :meth:`~ott.geometry.geometry.Geometry.to_LRCGeometry` used when
        the point cloud does not have squared Euclidean cost.

    Returns:
      For a squared Euclidean cost, the re-scaled low-rank geometry if
      :math:`n m > (n + m) d`, where :math:`n, m` is the shape and
      :math:`d` is the dimension of the point cloud; the unmodified point
      cloud if :math:`n m \le (n + m) d`. For any other cost, the result
      of the parent class conversion.
    """
    if not self.is_squared_euclidean:
        return super().to_LRCGeometry(scale=scale, **kwargs)

    if self._check_LRC_dim:
        return self._sqeucl_to_lr(scale)

    # we don't update the `scale_factor` because in GW, the linear cost
    # is first materialized and then scaled by `fused_penalty` afterwards
    return self
def _sqeucl_to_lr(self, scale: float = 1.0) -> low_rank.LRCGeometry:
    """Build the low-rank factors of the squared Euclidean cost.

    The factors are laid out so that ``cost_1 @ cost_2.T`` expands to
    ``|x|^2 + |y|^2 - 2 <x, y>``.

    Args:
      scale: Passed as ``scale_factor`` to the low-rank geometry.

    Returns:
      A low-rank geometry sharing this geometry's epsilon settings and
      ``scale_cost``.
    """
    assert self.is_squared_euclidean, "Geometry must be squared Euclidean."

    num_x, num_y = self.shape
    sqrt_two = 2.0 ** 0.5

    sq_norm_x = jnp.sum(self.x ** 2, axis=1, keepdims=True)
    sq_norm_y = jnp.sum(self.y ** 2, axis=1, keepdims=True)
    ones_x = jnp.ones((num_x, 1), dtype=self.dtype)
    ones_y = jnp.ones((num_y, 1), dtype=self.dtype)

    left = jnp.concatenate((sq_norm_x, ones_x, -sqrt_two * self.x), axis=1)
    right = jnp.concatenate((ones_y, sq_norm_y, sqrt_two * self.y), axis=1)

    return low_rank.LRCGeometry(
        cost_1=left,
        cost_2=right,
        scale_factor=scale,
        epsilon=self._epsilon_init,
        relative_epsilon=self._relative_epsilon,
        scale_cost=self._scale_cost,
    )
@property
def cost_matrix(self) -> Optional[jnp.ndarray]:
    """Unscaled cost matrix multiplied by :attr:`inv_scale_cost`."""
    inv_scale = self.inv_scale_cost
    return inv_scale * self._unscaled_cost_matrix
@property
def _unscaled_cost_matrix(self) -> jnp.ndarray:
    """All-pairs cost between ``x`` and ``y``, before any rescaling."""
    pairwise = self.cost_fn.all_pairs
    return pairwise(self.x, self.y)
@property
def inv_scale_cost(self) -> jnp.ndarray:
    """Multiplicative factor applied to the cost, i.e. ``1 / scale``.

    The scale is chosen from ``self._scale_cost``:

    - ``'max_cost'`` / ``'mean'``: maximum / mean of the unscaled cost,
      computed through :meth:`_compute_summary_online` when online;
    - ``'median'``: median of the unscaled cost (not available online);
    - ``'max_norm'``: largest value of ``cost_fn.norm`` over both clouds;
    - ``'max_bound'``: ``a + b + 2 * sqrt(a * b)`` with ``a``, ``b`` the
      largest norms of ``x`` and ``y`` (squared Euclidean cost only);
    - a scalar: that value itself.

    Raises:
      NotImplementedError: For ``'median'`` in online mode, and for
        ``'max_bound'`` when the cost is not squared Euclidean.
      ValueError: If ``self._scale_cost`` is none of the above.
    """
    mode = self._scale_cost

    if mode == "max_cost":
        if self.is_online:
            return 1.0 / self._compute_summary_online(mode)
        return 1.0 / jnp.max(self._unscaled_cost_matrix)

    if mode == "mean":
        if self.is_online:
            return 1.0 / self._compute_summary_online(mode)
        return 1.0 / jnp.mean(self._unscaled_cost_matrix)

    if mode == "median":
        if self.is_online:
            raise NotImplementedError(
                "Using the median as scaling factor for "
                "the cost matrix with the online mode is not implemented."
            )
        return 1.0 / jnp.median(self._unscaled_cost_matrix)

    if mode == "max_norm":
        norm_x = self.cost_fn.norm(self.x)
        norm_y = self.cost_fn.norm(self.y)
        return 1.0 / jnp.maximum(norm_x.max(), norm_y.max())

    if mode == "max_bound":
        # Reject unsupported costs before touching `cost_fn.norm`, so the
        # caller always gets the documented error for them.
        if not self.is_squared_euclidean:
            raise NotImplementedError(
                "Using max_bound as scaling factor for "
                "the cost matrix when the cost is not squared euclidean "
                "is not implemented."
            )
        max_x = self.cost_fn.norm(self.x).max()
        max_y = self.cost_fn.norm(self.y).max()
        max_bound = max_x + max_y + 2 * jnp.sqrt(max_x * max_y)
        return 1.0 / max_bound

    if utils.is_scalar(mode):
        return 1.0 / mode

    raise ValueError(f"Scaling {self._scale_cost} not implemented.")
[docs]
def subset(
    self,
    row_ixs: Optional[jnp.ndarray] = None,
    col_ixs: Optional[jnp.ndarray] = None,
) -> "PointCloud":
    """Restrict the geometry to selected points of ``x`` and/or ``y``.

    Args:
      row_ixs: Indices into ``x``; all of ``x`` is kept when :obj:`None`.
      col_ixs: Indices into ``y``; all of ``y`` is kept when :obj:`None`.

    Returns:
      A geometry of the same type rebuilt from the selected points and
      the remaining flattened state.
    """
    children, aux_data = self.tree_flatten()
    points_x, points_y, *others = children

    # `atleast_1d` keeps the leading axis even for a single index.
    if row_ixs is not None:
        points_x = points_x[jnp.atleast_1d(row_ixs)]
    if col_ixs is not None:
        points_y = points_y[jnp.atleast_1d(col_ixs)]

    return type(self).tree_unflatten(aux_data, (points_x, points_y, *others))
@property
def kernel_matrix(self) -> Optional[jnp.ndarray]:
    """Element-wise ``exp(-cost_matrix / epsilon)``."""
    exponent = -self.cost_matrix / self.epsilon
    return jnp.exp(exponent)
@property
def shape(self) -> Tuple[int, int]:
    """Leading dimensions of ``x`` and ``y``, as a pair."""
    num_x = self.x.shape[0]
    num_y = self.y.shape[0]
    return num_x, num_y
@property
def dtype(self) -> jnp.dtype:
    """Data type of ``x``."""
    points = self.x
    return points.dtype
@property
def is_symmetric(self) -> bool:
    """Whether both point clouds hold the same points.

    Returns:
      ``True`` when ``y`` is missing or is the very same object as ``x``,
      ``False`` when the two arrays differ in shape, and otherwise the
      result of comparing them element-wise with ``jnp.all``.
    """
    x, y = self.x, self.y
    # Decide the trivial cases before reading any shape: the identity
    # check also avoids an element-wise pass over the data.
    if y is None or x is y:
        return True
    # Compare full shapes so arrays that only agree on the number of
    # points are never broadcast against each other.
    if x.shape != y.shape:
        return False
    return jnp.all(x == y)
@property
def is_squared_euclidean(self) -> bool:
    """Whether ``cost_fn`` is an instance of :class:`costs.SqEuclidean`."""
    cost = self.cost_fn
    return isinstance(cost, costs.SqEuclidean)
@property
def can_LRC(self):
    """Whether the cost is squared Euclidean and passes ``_check_LRC_dim``."""
    squared = self.is_squared_euclidean
    if not squared:
        return squared
    return self._check_LRC_dim
@property
def _check_LRC_dim(self):
    """Whether ``n * m`` exceeds ``(n + m) * d`` for this point cloud.

    ``n, m`` come from :attr:`shape` and ``d`` is the second dimension
    of ``x``.
    """
    num_x, num_y = self.shape
    dim = self.x.shape[1]
    return num_x * num_y > (num_x + num_y) * dim
@property
def cost_rank(self) -> int:
    """Second dimension of ``x``."""
    _, dim = self.x.shape[0], self.x.shape[1]
    return dim
@property
def batch_size(self) -> Optional[int]:
    """Batch size for online mode.

    Returns:
      :obj:`None` when no batch size was requested, otherwise the
      requested value capped by the number of points in ``x`` and ``y``.
    """
    requested = self._batch_size
    if requested is None:
        return None
    return min(*self.shape, requested)
@property
def is_online(self) -> bool:
    """Whether the cost/kernel is computed on-the-fly.

    This is the case exactly when :attr:`batch_size` is not :obj:`None`.
    """
    effective = self.batch_size
    return effective is not None