# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import math
from typing import Tuple
import jax
import jax.numpy as jnp
from ott.math import fixed_point_loop
# Public API of this module. (A stray "[docs]" line, left over from a Sphinx
# HTML page and raising NameError at import, was removed here.)
__all__ = ["sqrtm", "sqrtm_only", "inv_sqrtm_only"]
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def sqrtm(
    x: jnp.ndarray,
    threshold: float = 1e-6,
    min_iterations: int = 0,
    inner_iterations: int = 10,
    max_iterations: int = 1000,
    regularization: float = 1e-6
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
  """Higham algorithm to compute matrix square root of p.d. matrix.

  See :cite:`higham:97`, eq. 2.6b. The coupled Newton-Schulz iteration is run
  on ``x / norm_x`` and both results are rescaled by ``sqrt(norm_x)`` at the
  end.

  Args:
    x: a (batch of) square p.s.d. matrices of the same size.
    threshold: convergence tolerance threshold for Newton-Schulz iterations.
    min_iterations: min number of iterations after which error is computed.
    inner_iterations: error is re-evaluated every inner_iterations iterations.
    max_iterations: max number of iterations.
    regularization: small regularizer added to norm of x, before normalization.

  Returns:
    Square root matrix of x (or x's if batch), its inverse, and the errors
    along iterates: one slot per block of ``inner_iterations`` iterations,
    with slots that were never reached left at ``-1``.
  """
  dimension = x.shape[-1]
  # Frobenius norm of each matrix, inflated by (1 + regularization); the
  # iteration itself runs on x / norm_x.
  norm_x = jnp.linalg.norm(x, axis=(-2, -1)) * (1 + regularization)
  if jnp.ndim(x) > 2:
    # Restore the two trailing axes so norm_x broadcasts against each matrix.
    norm_x = norm_x[..., jnp.newaxis, jnp.newaxis]

  def new_err(x, norm_x, y):
    """Largest relative residual ||x - norm_x * y @ y|| / ||x|| in the batch."""
    res = x - norm_x * jnp.matmul(y, y)
    norm_fn = functools.partial(jnp.linalg.norm, axis=(-2, -1))
    return jnp.max(norm_fn(res) / norm_fn(x))

  def cond_fn(iteration, const, state):
    """Stopping criterion. Checking decrease of objective is needed here.

    Keeps iterating on the first call, or while the error of the last
    completed block is finite and above ``threshold`` and the whole errors
    array is non-increasing.
    """
    _, threshold = const
    errors, _, _ = state
    # Slot written by the most recently completed block of inner iterations.
    err = errors[iteration // inner_iterations - 1]
    return jnp.logical_or(
        iteration == 0,
        jnp.logical_and(
            jnp.logical_and(jnp.isfinite(err), err > threshold),
            jnp.all(jnp.diff(errors) <= 0)
        )
    )  # check decreasing obj, else stop

  def body_fn(iteration, const, state, compute_error):
    """Carry out matrix updates on y and z, stores error if requested.

    Args:
      iteration: iteration number
      const: tuple of constant parameters that do not change throughout the
        loop.
      state: state variables currently updated in the loop.
      compute_error: flag to indicate this iteration computes/stores an error

    Returns:
      state variables.
    """
    x, _ = const
    errors, y, z = state
    # Coupled update: y <- y (1.5 I - 0.5 z y), z <- (1.5 I - 0.5 z y) z.
    w = 0.5 * jnp.matmul(z, y)
    y = 1.5 * y - jnp.matmul(y, w)
    z = 1.5 * z - jnp.matmul(w, z)
    # jnp.where is a select: new_err is evaluated either way, and iterations
    # that do not compute the error write inf into the current block's slot.
    err = jnp.where(compute_error, new_err(x, norm_x, y), jnp.inf)
    errors = errors.at[iteration // inner_iterations].set(err)
    return errors, y, z

  y = x / norm_x
  # Build z with y's dtype. A bare jnp.eye uses the global default float dtype
  # (float32, or float64 in x64 mode), so for e.g. float16 inputs, or float32
  # inputs with x64 enabled, the first matmul in body_fn would promote y and z
  # and change the dtype of the loop state between iterations. JAX loop
  # primitives (presumably used by fixpoint_iter_backprop) reject that.
  z = jnp.eye(dimension, dtype=y.dtype)
  if jnp.ndim(x) > 2:
    z = jnp.tile(z, list(x.shape[:-2]) + [1, 1])
  # One error slot per block of inner_iterations; -1 marks unused slots.
  errors = -jnp.ones(math.ceil(max_iterations / inner_iterations))
  state = (errors, y, z)
  const = (x, threshold)
  errors, y, z = fixed_point_loop.fixpoint_iter_backprop(
      cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, const,
      state
  )
  # Undo the normalization: y ~ (x / norm_x)^{1/2}, z ~ (x / norm_x)^{-1/2}.
  sqrt_x = jnp.sqrt(norm_x) * y
  inv_sqrt_x = z / jnp.sqrt(norm_x)
  return sqrt_x, inv_sqrt_x, errors
def solve_sylvester_bartels_stewart(
    a: jnp.ndarray,
    b: jnp.ndarray,
    c: jnp.ndarray,
) -> jnp.ndarray:
  """Solve the real Sylvester equation AX - XB = C using Bartels-Stewart.

  Both ``a`` and ``b`` are reduced to complex Schur form, the resulting
  triangular system is solved column by column, and the solution is mapped
  back. A unique solution requires that ``a`` and ``b`` share no eigenvalue;
  otherwise the triangular solves below are singular.

  Args:
    a: (batch of) square matrices of shape ``[..., m, m]``.
    b: (batch of) square matrices of shape ``[..., n, n]``.
    c: right-hand side of shape ``[..., m, n]``.

  Returns:
    The real part of the solution ``X``, of shape ``[..., m, n]``. Its
    precision follows that of the inputs (float32 in, float32 out), not the
    global default dtype.
  """
  # See https://nhigham.com/2020/09/01/what-is-the-sylvester-equation/ for
  # discussion of the algorithm (but note that in the derivation, the sign on
  # the right hand side is flipped in the equation in which the columns are set
  # to be equal).
  m = a.shape[-1]
  n = b.shape[-1]
  # Cast a and b to complex to ensure we get the complex Schur decomposition
  # (the real Schur decomposition may not give an upper triangular solution).
  # For the decomposition below, a = u r u* and b = v s v*
  r, u = jax.lax.linalg.schur(a + 0j)
  s, v = jax.lax.linalg.schur(b + 0j)
  # Transformed right-hand side d = u* c v; with y = u* X v the equation
  # becomes r y - y s = d.
  d = jnp.matmul(
      jnp.conjugate(jnp.swapaxes(u, axis1=-2, axis2=-1)), jnp.matmul(c, v)
  )
  # The solution in the transformed space will in general be complex, too.
  # Allocate it with d's dtype and (broadcast) shape. A default-dtype buffer
  # would tie the output precision to the global x64 flag instead of the
  # inputs, e.g. returning float64 for float32 inputs in x64 mode.
  y = jnp.zeros_like(d)
  idx = jnp.arange(m)
  # Since s is upper triangular, column j of r y - y s = d reads
  #   (r - s[j, j] I) y[:, j] = d[:, j] + sum_{k<j} y[:, k] s[k, j],
  # an upper-triangular system once columns 0..j-1 are known.
  for j in range(n):
    lhs = r.at[..., idx, idx].add(-s[..., j:j + 1, j])
    rhs = d[..., j] + jnp.matmul(y[..., :j], s[..., :j, j:j + 1])[..., 0]
    y = y.at[..., j].set(jax.scipy.linalg.solve_triangular(lhs, rhs))
  x = jnp.matmul(
      u, jnp.matmul(y, jnp.conjugate(jnp.swapaxes(v, axis1=-2, axis2=-1)))
  )
  # The end result should be real; remove the imaginary part of the solution.
  return jnp.real(x)
def sqrtm_fwd(
    x: jnp.ndarray,
    threshold: float,
    min_iterations: int,
    inner_iterations: int,
    max_iterations: int,
    regularization: float,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray,
                                                              jnp.ndarray]]:
  """Forward rule of the custom VJP attached to :func:`sqrtm`.

  Runs the primal computation. It keeps the square root and its inverse as
  residuals; the error trace is not needed by the backward rule.

  Returns:
    A pair ``(primal_outputs, residuals)`` with ``primal_outputs`` equal to
    ``(sqrt_x, inv_sqrt_x, errors)`` and ``residuals`` equal to
    ``(sqrt_x, inv_sqrt_x)``.
  """
  primal = sqrtm(
      x=x,
      threshold=threshold,
      min_iterations=min_iterations,
      inner_iterations=inner_iterations,
      max_iterations=max_iterations,
      regularization=regularization,
  )
  root, inv_root, error_trace = primal
  saved_for_bwd = (root, inv_root)
  return (root, inv_root, error_trace), saved_for_bwd
def sqrtm_bwd(
    threshold: float,
    min_iterations: int,
    inner_iterations: int,
    max_iterations: int,
    regularization: float,
    residual: Tuple[jnp.ndarray, jnp.ndarray],
    cotangent: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
  """Compute the derivative by solving a Sylvester equation.

  The contributions of the square-root output and of the inverse-square-root
  output are computed separately, each with one Sylvester solve, and summed.
  The cotangent attached to the error trace is ignored.
  """
  # The iteration settings only shape the forward solve; the gradient is
  # obtained in closed form from the residuals.
  del threshold, min_iterations, inner_iterations
  del max_iterations, regularization

  def _swap_last_two(mat):
    return jnp.swapaxes(mat, -1, -2)

  root, inv_root = residual
  g_root, g_inv_root, _ = cotangent

  # Square root S = X^{1/2}: differentiating S S = X gives the Sylvester
  # equation dS S + S dS = dX (A = S, B = -S). The transposed cotangent is used
  # as the right-hand side and the solution is transposed back, following the
  # JAX autodiff cookbook convention for VJPs.
  from_root = _swap_last_two(
      solve_sylvester_bartels_stewart(
          a=root, b=-root, c=_swap_last_two(g_root)
      )
  )

  # Inverse square root T = X^{-1/2}: differentiating T T = X^{-1}, with
  # d(X^{-1}) = -X^{-1} dX X^{-1}, gives dT T + T dT = -X^{-1} dX X^{-1}.
  # The same transposed-Sylvester trick applies with A = T, B = -T.
  x_inv = jnp.matmul(inv_root, inv_root)
  sylvester_rhs = -jnp.matmul(
      x_inv, jnp.matmul(_swap_last_two(g_inv_root), x_inv)
  )
  from_inv_root = _swap_last_two(
      solve_sylvester_bartels_stewart(a=inv_root, b=-inv_root, c=sylvester_rhs)
  )

  return (from_root + from_inv_root,)
# Attach the forward rule (which saves sqrt_x and inv_sqrt_x) and the
# Sylvester-based backward rule to the custom_vjp-wrapped `sqrtm`.
sqrtm.defvjp(sqrtm_fwd, sqrtm_bwd)
# Specialized versions of sqrtm that return only the square root, or only its
# inverse. Their custom VJPs need a single Sylvester solve instead of the two
# performed in sqrtm_bwd, so their gradients are cheaper to compute.
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def sqrtm_only(
    x: jnp.ndarray,
    threshold: float = 1e-6,
    min_iterations: int = 0,
    inner_iterations: int = 10,
    max_iterations: int = 1000,
    regularization: float = 1e-6
) -> jnp.ndarray:
  """Matrix square root of a (batch of) p.s.d. matrices.

  Runs the same forward computation as :func:`sqrtm` but returns only the
  square root. Its custom VJP needs a single Sylvester solve.

  Args:
    x: a (batch of) square p.s.d. matrices of the same size.
    threshold: convergence tolerance threshold for Newton-Schulz iterations.
    min_iterations: min number of iterations after which error is computed.
    inner_iterations: error is re-evaluated every inner_iterations iterations.
    max_iterations: max number of iterations.
    regularization: small regularizer added to norm of x, before normalization.

  Returns:
    The square root of ``x`` (or of each matrix in the batch).
  """
  root, _, _ = sqrtm(
      x,
      threshold,
      min_iterations,
      inner_iterations,
      max_iterations,
      regularization,
  )
  return root
def sqrtm_only_fwd(
    x: jnp.ndarray, threshold: float, min_iterations: int,
    inner_iterations: int, max_iterations: int, regularization: float
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Forward rule for :func:`sqrtm_only`; the output doubles as residual."""
  root, _, _ = sqrtm(
      x, threshold, min_iterations, inner_iterations, max_iterations,
      regularization
  )
  return root, root
def sqrtm_only_bwd(
    threshold: float, min_iterations: int, inner_iterations: int,
    max_iterations: int, regularization: float, sqrt_x: jnp.ndarray,
    cotangent: jnp.ndarray
) -> Tuple[jnp.ndarray]:
  """Backward rule for :func:`sqrtm_only`.

  With S = X^{1/2}, differentiating S S = X gives dS S + S dS = dX. The
  cotangent of X is found by solving that Sylvester equation (A = S, B = -S)
  on the transposed cotangent and transposing the solution back.
  """
  # Iteration settings do not enter the closed-form gradient.
  del threshold, min_iterations, inner_iterations
  del max_iterations, regularization
  rhs = jnp.swapaxes(cotangent, -2, -1)
  solution = solve_sylvester_bartels_stewart(a=sqrt_x, b=-sqrt_x, c=rhs)
  return (jnp.swapaxes(solution, -2, -1),)
# Attach the single-Sylvester-solve forward/backward rules to `sqrtm_only`.
sqrtm_only.defvjp(sqrtm_only_fwd, sqrtm_only_bwd)
@functools.partial(jax.custom_vjp, nondiff_argnums=(1, 2, 3, 4, 5))
def inv_sqrtm_only(
    x: jnp.ndarray,
    threshold: float = 1e-6,
    min_iterations: int = 0,
    inner_iterations: int = 10,
    max_iterations: int = 1000,
    regularization: float = 1e-6
) -> jnp.ndarray:
  """Inverse matrix square root of a (batch of) p.s.d. matrices.

  Runs the same forward computation as :func:`sqrtm` but returns only the
  inverse square root. Its custom VJP needs a single Sylvester solve.

  Args:
    x: a (batch of) square p.s.d. matrices of the same size.
    threshold: convergence tolerance threshold for Newton-Schulz iterations.
    min_iterations: min number of iterations after which error is computed.
    inner_iterations: error is re-evaluated every inner_iterations iterations.
    max_iterations: max number of iterations.
    regularization: small regularizer added to norm of x, before normalization.

  Returns:
    The inverse square root of ``x`` (or of each matrix in the batch).
  """
  _, inv_root, _ = sqrtm(
      x,
      threshold,
      min_iterations,
      inner_iterations,
      max_iterations,
      regularization,
  )
  return inv_root
def inv_sqrtm_only_fwd(
    x: jnp.ndarray,
    threshold: float,
    min_iterations: int,
    inner_iterations: int,
    max_iterations: int,
    regularization: float,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Forward rule for :func:`inv_sqrtm_only`; the output doubles as residual."""
  _, inv_root, _ = sqrtm(
      x, threshold, min_iterations, inner_iterations, max_iterations,
      regularization
  )
  return inv_root, inv_root
def inv_sqrtm_only_bwd(
    threshold: float, min_iterations: int, inner_iterations: int,
    max_iterations: int, regularization: float, residual: jnp.ndarray,
    cotangent: jnp.ndarray
) -> Tuple[jnp.ndarray]:
  """Backward rule for :func:`inv_sqrtm_only`.

  With T = X^{-1/2}, differentiating T T = X^{-1} and using
  d(X^{-1}) = -X^{-1} dX X^{-1} gives dT T + T dT = -X^{-1} dX X^{-1}. The
  cotangent of X is found by solving that Sylvester equation (A = T, B = -T)
  with the transposed cotangent, then transposing the solution back.
  """
  # Iteration settings do not enter the closed-form gradient.
  del threshold, min_iterations, inner_iterations
  del max_iterations, regularization
  inv_root = residual
  x_inv = jnp.matmul(inv_root, inv_root)  # X^{-1} = T T
  cot_t = jnp.swapaxes(cotangent, -2, -1)
  rhs = -jnp.matmul(x_inv, jnp.matmul(cot_t, x_inv))
  solution = solve_sylvester_bartels_stewart(a=inv_root, b=-inv_root, c=rhs)
  return (jnp.swapaxes(solution, -2, -1),)
# Attach the single-Sylvester-solve forward/backward rules to `inv_sqrtm_only`.
inv_sqrtm_only.defvjp(inv_sqrtm_only_fwd, inv_sqrtm_only_bwd)