ott.math.utils.norm

ott.math.utils.norm(x, ord=None, axis=None, keepdims=False) = <jax._src.custom_derivatives.custom_jvp object>

Compute the order-ord norm of a vector (or matrix), using jnp.linalg.norm in the forward pass.

Evaluating the distance between a vector and itself with a translation-invariant cost, typically a norm, results in functions of the form lambda x: jnp.linalg.norm(x - x). Such functions have NaN gradients, because differentiating them involves raising 0 to a negative power: when differentiating the Euclidean norm, for instance, a 0 appears in the denominator (see e.g. google/jax#6484 for context).
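To make the failure concrete, the snippet below differentiates the self-distance lambda x: jnp.linalg.norm(x - x) with plain JAX (ott is not involved). Assuming jnp.linalg.norm computes the 2-norm as the square root of a sum of squares, the backward pass multiplies the infinite factor 0.5 / sqrt(0) by a zero factor, which one would expect to yield NaN entries in the gradient even though the value itself is a harmless 0.

```python
import jax
import jax.numpy as jnp


def self_distance(x):
  # Euclidean distance between x and itself: the value is always 0.
  return jnp.linalg.norm(x - x)


x = jnp.array([1.0, 2.0, 3.0])
value = self_distance(x)
grad = jax.grad(self_distance)(x)
print(value)  # the distance itself is fine
print(grad)   # the gradient contains NaN entries
```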

While this is mathematically sound, in optimal transport such distances between a point and itself can safely be ignored when they contribute to an OT cost, for instance when computing Sinkhorn divergences, which involve the OT cost of a point cloud with itself.

To avoid such NaN values, this custom norm implementation uses the double-where trick, so that no branch of the computation outputs NaN and a 0 gradient is safely returned instead.
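The double-where trick can be sketched with jax.custom_jvp as follows. This is a minimal, hypothetical re-implementation for illustration (the names safe_norm and _safe_norm_jvp are not part of ott), assuming that only an all-zero input is treated as the singular point. The forward pass calls jnp.linalg.norm unchanged; the JVP rule first swaps a zero input for a harmless placeholder (first where), so the differentiated computation never sees a 0 denominator, and then masks the resulting tangent to 0 (second where).

```python
import functools

import jax
import jax.numpy as jnp


# Hypothetical sketch; ord, axis and keepdims are non-differentiable args.
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 3))
def safe_norm(x, ord=None, axis=None, keepdims=False):
  # Forward pass: plain jnp.linalg.norm.
  return jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)


@safe_norm.defjvp
def _safe_norm_jvp(ord, axis, keepdims, primals, tangents):
  (x,), (x_dot,) = primals, tangents
  # Assumption: only an all-zero input is treated as singular.
  is_zero = jnp.all(x == 0)
  # First where: replace the zero input by ones so the JVP below is finite.
  safe_x = jnp.where(is_zero, jnp.ones_like(x), x)

  def norm_fn(z):
    return jnp.linalg.norm(z, ord=ord, axis=axis, keepdims=keepdims)

  _, tangent_out = jax.jvp(norm_fn, (safe_x,), (x_dot,))
  # The primal output is still computed from the original input.
  primal_out = norm_fn(x)
  # Second where: output a zero tangent at the singular point.
  tangent_out = jnp.where(is_zero, 0.0, tangent_out)
  return primal_out, tangent_out


x = jnp.array([1.0, 2.0, 3.0])
print(jax.grad(lambda x: safe_norm(x - x))(x))  # zeros instead of NaN
```

Because is_zero inspects the whole array, a batched call with axis set, in which only some slices are zero, would still produce NaN gradients for those slices; handling that case would require a per-slice mask.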

Parameters:
  • x (Array) – Input array. If axis is None, x must be 1-D or 2-D, unless ord is None. If both axis and ord are None, the 2-norm of x.ravel() will be returned.

  • ord (Union[int, str, None]) – {non-zero int, jnp.inf, -jnp.inf, 'fro', 'nuc'}, Order of the norm. The default is None, which is equivalent to 2.0 for vectors and to the Frobenius norm ('fro') for matrices.

  • axis (Union[None, Sequence[int], int]) – {None, int, 2-tuple of ints}, optional. If axis is an integer, it specifies the axis of x along which to compute the vector norms. If axis is a 2-tuple, it specifies the axes that hold 2-D matrices, and the matrix norms of these matrices are computed. If axis is None then either a vector norm (when x is 1-D) or a matrix norm (when x is 2-D) is returned. The default is None.

  • keepdims (bool) – If set to True, the axes which are normed over are left in the result as dimensions with size one. With this option the result will broadcast correctly against the original x. The default is False.
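Since the forward pass delegates to jnp.linalg, the parameters above can be illustrated with jnp.linalg.norm directly. The sketch below does so (it runs without ott installed) to show the shapes produced by the different axis, ord and keepdims combinations; the array values are arbitrary.

```python
import jax.numpy as jnp

x = jnp.arange(12.0).reshape(4, 3)

# axis as an int: one vector 2-norm per row (axis 1 is reduced)
row_norms = jnp.linalg.norm(x, axis=1)
# keepdims=True keeps the reduced axis as a dimension of size one
row_norms_kept = jnp.linalg.norm(x, axis=1, keepdims=True)
# axis=None on a 2-D input: a matrix norm (Frobenius here)
fro = jnp.linalg.norm(x, ord="fro")
# axis=None and ord=None on a 3-D input: 2-norm of the flattened array
flat = jnp.linalg.norm(x.reshape(2, 2, 3))

print(row_norms.shape, row_norms_kept.shape)  # (4,) (4, 1)
# The kept dimension lets the result broadcast against x.
unit_rows = x / row_norms_kept
```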

Return type:

Array

Returns:

Array, Norm of the matrix or vector(s).