# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import TYPE_CHECKING, Optional, Sequence, Tuple, Union
import jax
import jax.numpy as jnp
import jax.scipy as jsp
if TYPE_CHECKING:
from ott.geometry import costs
# Public names exported by ``from <this module> import *``.
__all__ = [
    "safe_log", "norm", "kl", "gen_kl", "gen_js", "logsumexp", "softmin",
    "barycentric_projection", "sort_and_argsort", "lambertw"
]
def safe_log(
    x: jnp.ndarray,
    *,
    eps: Optional[float] = None
) -> jnp.ndarray:
  """Elementwise logarithm with a fallback value for non-positive entries.

  Args:
    x: Input array.
    eps: Entries where ``x <= 0`` evaluate to ``log(eps)``. If ``None``, the
      smallest positive normal number of ``x``'s dtype
      (``jnp.finfo(x.dtype).tiny``) is used.

  Returns:
    ``log(x)`` where ``x > 0`` and ``log(eps)`` elsewhere.
  """
  x = jnp.asarray(x)
  if eps is None:
    eps = jnp.finfo(x.dtype).tiny
  positive = x > 0.0
  # Double-where trick: ``log`` only ever sees positive values, so the
  # discarded branch cannot back-propagate ``0 * inf = NaN`` at ``x == 0``
  # (a single ``jnp.where`` around ``jnp.log(x)`` would).
  log_x = jnp.log(jnp.where(positive, x, 1.0))
  return jnp.where(positive, log_x, jnp.log(eps))
[docs]
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 3))
@functools.partial(jax.jit, static_argnames=("ord", "axis", "keepdims"))
def norm(
    x: jnp.ndarray,
    ord: Union[int, str, None] = None,
    axis: Union[None, Sequence[int], int] = None,
    keepdims: bool = False
) -> jnp.ndarray:
  """Computes order ord norm of vector, using `jnp.linalg` in forward pass.

  Evaluations of distances between a vector and itself using translation
  invariant costs, typically norms, result in functions of the form
  ``lambda x : jnp.linalg.norm(x-x)``. Such functions output `NaN` gradients,
  because they involve computing the derivative of a negative exponent of 0
  (e.g. when differentiating the Euclidean norm, one gets a 0-denominator in the
  expression, see e.g. https://github.com/google/jax/issues/6484 for context).
  While this makes sense mathematically, in the context of optimal transport
  such distances between a point and itself can be safely ignored when they
  contribute to an OT cost (when, for instance, computing Sinkhorn divergences,
  involving computing the OT cost of a point cloud with itself).

  To avoid such `NaN` values, this custom norm implementation uses the
  double-where trick, to avoid having branches that output any `NaN`, and
  safely output a 0 instead. The zero test is done separately for every slice
  that is normed over (as selected by ``axis``), so a zero vector inside a
  batch gets a 0 gradient without affecting the other vectors.

  Args:
    x: Input array. If `axis` is None, `x` must be 1-D or 2-D, unless `ord`
      is None. If both `axis` and `ord` are None, the 2-norm of ``x.ravel``
      will be returned.
    ord: `{non-zero int, jnp.inf, -jnp.inf, 'fro', 'nuc'}`, Order of the norm.
      The default is `None`, which is equivalent to `2.0` for vectors.
    axis: `{None, int, 2-tuple of ints}`, optional. If `axis` is an integer, it
      specifies the axis of `x` along which to compute the vector norms.
      If `axis` is a 2-tuple, it specifies the axes that hold 2-D matrices, and
      the matrix norms of these matrices are computed. If `axis` is None then
      either a vector norm (when `x` is 1-D) or a matrix norm (when `x` is 2-D)
      is returned. The default is None.
    keepdims: If set to True, the axes which are normed over are left in the
      result as dimensions with size one. With this option the result will
      broadcast correctly against the original `x`.

  Returns:
    float or ndarray, Norm of the matrix or vector(s).
  """
  return jnp.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)


@norm.defjvp
def norm_jvp(ord, axis, keepdims, primals, tangents):
  """Custom_jvp for norm, that returns 0.0 when evaluated at 0.

  Every slice of ``x`` that is entirely zero (per ``axis``) is replaced by ones
  before differentiating, so no branch produces ``NaN``. The primal output and
  the tangent for those slices are then set to 0 (the norm of a zero slice).
  """
  x, = primals
  norm_fn = functools.partial(
      jnp.linalg.norm, ord=ord, axis=axis, keepdims=keepdims
  )
  # Per-slice zero mask, broadcastable against ``x``.
  is_zero = jnp.all(x == 0, axis=axis, keepdims=True)
  clean_x = jnp.where(is_zero, jnp.ones_like(x), x)
  primal_out, tangent_out = jax.jvp(norm_fn, (clean_x,), tangents)
  # Same mask, reduced to the shape of the norm's output.
  is_zero_out = jnp.all(x == 0, axis=axis, keepdims=keepdims)
  # The primal was computed on ``clean_x``; restore the true value (0) for
  # zero slices instead of returning the norm of the ones placeholder.
  primal_out = jnp.where(is_zero_out, 0.0, primal_out)
  tangent_out = jnp.where(is_zero_out, 0.0, tangent_out)
  return primal_out, tangent_out
# TODO(michalk8): add axis argument
def kl(p: jnp.ndarray, q: jnp.ndarray) -> float:
  """Kullback-Leibler divergence ``<p, log(p) - log(q)>``.

  ``jnp.vdot`` flattens both operands, so the result is a scalar. Logs are
  taken with ``safe_log``, so non-positive entries do not produce ``-inf``.
  """
  log_ratio = safe_log(p) - safe_log(q)
  return jnp.vdot(p, log_ratio)
def gen_kl(p: jnp.ndarray, q: jnp.ndarray) -> float:
  """Generalized Kullback-Leibler divergence for non-normalized inputs.

  Computes ``<p, log(p) - log(q)> + sum(q) - sum(p)``, with logs taken by
  ``safe_log``.
  """
  cross_term = jnp.vdot(p, safe_log(p) - safe_log(q))
  mass_q, mass_p = jnp.sum(q), jnp.sum(p)
  return cross_term + mass_q - mass_p
# TODO(michalk8): add axis argument
def gen_js(p: jnp.ndarray, q: jnp.ndarray, c: float = 0.5) -> float:
  """Symmetrized generalized KL divergence, scaled by ``c``.

  Returns ``c * (gen_kl(p, q) + gen_kl(q, p))``. Despite the name, no mixture
  ``(p + q) / 2`` is formed, so this is not the textbook Jensen-Shannon
  divergence. With the default ``c = 0.5`` it is the average of the two
  directed generalized KL divergences.
  """
  p_to_q = gen_kl(p, q)
  q_to_p = gen_kl(q, p)
  return c * (p_to_q + q_to_p)
[docs]
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2, 4))
def logsumexp(mat, axis=None, keepdims=False, b=None, return_sign=False):
  """Log-sum-exp whose derivative does not blow up on ``-inf`` slices.

  The forward pass is :func:`jax.scipy.special.logsumexp`. The custom JVP rule
  (``logsumexp_jvp``) avoids ``NaN`` derivatives when a slice is all ``-inf``.

  Args:
    mat: Input array.
    axis: Axis or axes to reduce over; ``None`` reduces over all axes.
    keepdims: Whether reduced axes are kept as size-one dimensions.
    b: Optional multiplicative weights applied to ``exp(mat)``.
    return_sign: If ``True``, also return the sign of the weighted sum.

  Returns:
    The log of the (absolute value of the) weighted sum of exponentials, and,
    if ``return_sign`` is ``True``, its sign.
  """
  return jax.scipy.special.logsumexp(
      mat, axis=axis, keepdims=keepdims, b=b, return_sign=return_sign
  )


@logsumexp.defjvp
def logsumexp_jvp(axis, keepdims, return_sign, primals, tangents):
  """Custom derivative rule for lse that does not blow up with -inf.

  This logsumexp implementation uses the standard jax one in forward mode but
  implements a custom rule to differentiate. Given the preference of jax for
  jvp over vjp, and the fact that this is a simple linear rule, jvp is used.
  This custom differentiation addresses issues when the output of lse is
  -inf (which corresponds to the case where all inputs in a slice are -inf,
  which happens typically when ``a`` or ``b`` weight vectors have zeros.)
  Although both exp(lse) and its derivative should be 0, automatic
  differentiation returns a NaN derivative because of a -inf - (-inf) operation
  appearing in the definition of centered_exp below. This is corrected in the
  implementation below by centering non-finite slices at 0; the primal ``lse``
  itself is returned unchanged (``-inf`` stays ``-inf``).

  Args:
    axis: argument from original logsumexp
    keepdims: argument from original logsumexp
    return_sign: argument from original logsumexp
    primals: mat and b, the two arguments against which we differentiate.
    tangents: of same size as mat and b.

  Returns:
    original primal outputs + their tangent.
  """  # noqa: D401
  mat, b = primals
  tan_mat, tan_b = tangents
  lse = logsumexp(mat, axis, keepdims, b, return_sign)
  if return_sign:
    lse, sign = lse
  # Centering offset only; kept separate so the primal ``lse`` is not altered.
  shift = jnp.where(jnp.isfinite(lse), lse, 0.0)
  # Make ``shift`` broadcastable against ``mat``: a keepdims result already
  # is, and a scalar (``axis=None``) broadcasts as-is; otherwise re-insert the
  # reduced axes.
  if not keepdims and axis is not None:
    shift = jnp.expand_dims(shift, axis=axis)
  centered_exp = jnp.exp(mat - shift)

  if b is None:
    res = jnp.sum(centered_exp * tan_mat, axis=axis, keepdims=keepdims)
  else:
    res = jnp.sum(b * centered_exp * tan_mat, axis=axis, keepdims=keepdims)
    res += jnp.sum(tan_b * centered_exp, axis=axis, keepdims=keepdims)
  if return_sign:
    return (lse, sign), (sign * res, jnp.zeros_like(sign))
  return lse, res
[docs]
@functools.partial(jax.custom_vjp, nondiff_argnums=(2,))
def softmin(
    x: jnp.ndarray, gamma: float, axis: Optional[int] = None
) -> jnp.ndarray:
  r"""Soft-min operator.

  Computes :math:`-\gamma \log \sum \exp(-x / \gamma)` over ``axis``.

  Args:
    x: Input data.
    gamma: Smoothing parameter :math:`> 0`.
    axis: Axis or axes over which to operate. If ``None``, use flattened input.

  Returns:
    The soft minimum.

  Note:
    The custom VJP back-propagates to ``x`` the softmax of ``-x / gamma``
    (set to 0 where ``-x / gamma`` is infinite). It returns a ``None`` (zero)
    cotangent for ``gamma``.
  """
  scaled = x / -gamma
  return -gamma * jsp.special.logsumexp(scaled, axis=axis)


def _softmin_fwd(x, gamma, axis):
  # Forward pass: the soft-min value, plus ``x / -gamma`` kept for backward.
  scaled = x / -gamma
  value = -gamma * jsp.special.logsumexp(scaled, axis=axis)
  return value, (scaled, axis)


def _softmin_bwd(axis, res, g):
  # Backward pass: softmax weights times the incoming cotangent, with the
  # cotangent re-expanded over the reduced axis so it broadcasts.
  scaled, _ = res
  if axis is not None:
    g = jnp.expand_dims(g, axis=axis)
  weights = jax.nn.softmax(scaled, axis=axis)
  grad_x = jnp.where(jnp.isinf(scaled), 0.0, weights * g)
  return grad_x, None


softmin.defvjp(_softmin_fwd, _softmin_bwd)
@functools.partial(jax.vmap, in_axes=[0, 0, None])
def barycentric_projection(
    matrix: jnp.ndarray, y: jnp.ndarray, cost_fn: "costs.CostFn"
) -> jnp.ndarray:
  """Compute the barycentric projection of a matrix.

  Because of the outer ``vmap``, ``matrix`` and ``y`` carry an extra leading
  batch axis, and ``cost_fn`` is shared across the batch. For each batch
  element, every row of ``matrix`` is passed together with the whole ``y`` to
  ``cost_fn.barycenter``, and the first element of its output is kept.
  (Presumably rows of ``matrix`` are weights over the points in ``y`` --
  confirm against ``CostFn.barycenter``.)

  Args:
    matrix: per batch element, a matrix of shape (n, m).
    y: per batch element, the points passed to ``cost_fn.barycenter``.
    cost_fn: a CostFn instance.

  Returns:
    The stacked barycenters, one per row of ``matrix``, for every batch element.
  """
  def _row_barycenter(weights: jnp.ndarray, points: jnp.ndarray) -> jnp.ndarray:
    # ``barycenter`` returns a tuple; only its first entry is used.
    return cost_fn.barycenter(weights, points)[0]

  return jax.vmap(_row_barycenter, in_axes=(0, None))(matrix, y)
def sort_and_argsort(
    x: jnp.ndarray,
    *,
    argsort: bool = False
) -> Tuple[jnp.ndarray, Optional[jnp.ndarray]]:
  """Unified function that returns both sort and argsort, if latter needed.

  Sorting is along the last axis, as in :func:`jax.numpy.sort`.

  Args:
    x: Array to sort.
    argsort: Whether to also return the sorting indices.

  Returns:
    The sorted array and, if ``argsort`` is ``True``, the indices ``i`` such
    that ``jnp.take_along_axis(x, i, axis=-1)`` is the sorted array; otherwise
    ``None`` in place of the indices.
  """
  if not argsort:
    return jnp.sort(x), None
  i_x = jnp.argsort(x)
  # ``take_along_axis`` instead of ``x[i_x]``: identical for 1-D inputs, but
  # ``x[i_x]`` would fancy-index the leading axis for higher-dimensional ones.
  return jnp.take_along_axis(x, i_x, axis=-1), i_x
[docs]
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 2))
def lambertw(
    z: jnp.ndarray, tol: float = 1e-8, max_iter: int = 100
) -> jnp.ndarray:
  """Principal branch of the
  `Lambert W function <https://en.wikipedia.org/wiki/Lambert_W_function>`_.

  This implementation uses Halley's iteration and the global initialization
  proposed in :cite:`iacono:17`, Eq. 20 .

  Args:
    z: Array.
    tol: Relative tolerance; an entry is considered converged once its Halley
      step satisfies ``|delta| <= tol * |w|``.
    max_iter: Maximum number of iterations.

  Returns:
    The Lambert W evaluated at ``z``.
  """  # noqa: D205

  def initial_iacono(x: jnp.ndarray) -> jnp.ndarray:
    # Closed-form initial guess (Iacono & Boyd, Eq. 20).
    y = jnp.sqrt(1.0 + jnp.e * x)
    num = 1.0 + 1.14956131 * y
    denom = 1.0 + 0.45495740 * jnp.log1p(y)
    return -1.0 + 2.036 * jnp.log(num / denom)

  def cond_fun(container):
    it, converged, _ = container
    # Continue while any entry is unconverged and the budget is not exhausted.
    return jnp.logical_and(jnp.any(~converged), it < max_iter)

  def halley_iteration(container):
    it, _, w = container
    # modified from `tensorflow_probability`
    # Halley step for ``f(w) = w - z * exp(-w)``, whose root solves
    # ``w * exp(w) = z``.
    f = w - z * jnp.exp(-w)
    delta = f / (w + 1.0 - 0.5 * (w + 2.0) * f / (w + 1.0))
    w_next = w - delta
    converged = jnp.abs(delta) <= tol * jnp.abs(w_next)
    return it + 1, converged, w_next

  w0 = initial_iacono(z)
  converged = jnp.zeros_like(w0, dtype=bool)
  _, _, w = jax.lax.while_loop(
      cond_fun=cond_fun, body_fun=halley_iteration, init_val=(0, converged, w0)
  )
  return w


@lambertw.defjvp
def _lambertw_jvp(
    tol: float, max_iter: int, primals: Tuple[jnp.ndarray, ...],
    tangents: Tuple[jnp.ndarray, ...]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """JVP of :func:`lambertw`, using ``W'(z) = exp(-W) / (1 + W)``.

  Differentiating ``W exp(W) = z`` gives this form, which equals
  ``W / (z (1 + W))`` for ``z != 0``. Unlike that form, it has no 0/0 at
  ``z = 0`` (where it gives ``W'(0) = 1``), so no ``where`` special case is
  needed and higher-order derivatives stay finite there.
  """
  z, = primals
  dz, = tangents
  w = lambertw(z, tol=tol, max_iter=max_iter)
  pz = jnp.exp(-w) / (1.0 + w)
  return w, pz * dz