# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from typing import (
Any,
Callable,
Dict,
Literal,
Optional,
Sequence,
Tuple,
Union,
)
import jax
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import numpy as np
import diffrax
import optax
from flax import nnx
# Names exported as the public API of this module.
__all__ = [
"flow_matching_step",
"interpolate_samples",
"evaluate_velocity_field",
"curvature",
"gaussian_nll",
]
# Pair of arrays used by the divergence helpers below: they unpack the first
# entry of the incoming state as the position ``x_t`` and return the pair
# ``(v_t, div_t)``, i.e. the velocity and its divergence.
DivState = Tuple[jax.Array, jax.Array] # velocity, divergence
# Training batch keyed by ``'t'``, ``'x_t'``, ``'v_t'`` and ``'cond'``;
# ``interpolate_samples`` only adds ``'cond'`` when a condition is passed.
Batch = Dict[Literal["t", "x_t", "v_t", "cond"], jax.Array]
[docs]
def flow_matching_step(
    model: nnx.Module,
    optimizer: nnx.Optimizer,
    batch: Batch,
    *,
    loss_fn: Callable[[jax.Array, jax.Array], jax.Array] = optax.squared_error,
    model_callback_fn: Optional[Callable[[nnx.Module], None]] = None,
    rngs: Optional[nnx.Rngs] = None,
) -> Dict[Literal["loss", "grad_norm"], jax.Array]:
  """Run a single flow matching optimization step.

  The model is evaluated on the batch, the element-wise loss between the
  predicted and the target velocity is averaged, and the resulting gradients
  are handed to the optimizer.

  Args:
    model: Velocity field called as ``model(t, x_t, cond, rngs=rngs)``.
    optimizer: Optimizer whose ``update`` method applies the gradients.
    batch: Mapping with the following entries:

      - ``'t'`` - time.
      - ``'x_t'`` - position.
      - ``'v_t'`` - target velocity.
      - ``'cond'`` - condition (optional); :obj:`None` is passed to the
        model when the key is missing.
    loss_fn: Loss function with a signature ``(pred, target) -> loss``;
      its output is reduced with ``.mean()``.
    model_callback_fn: Optional function with a signature ``(model) -> None``
      that is called after the optimizer update.
    rngs: Random number generators forwarded to the model.

  Returns:
    Dictionary with the mean loss under ``'loss'`` and the global norm of
    the gradients under ``'grad_norm'``.
  """

  def mean_loss(net: nnx.Module, keys: nnx.Rngs) -> jax.Array:
    prediction = model_output = net(
        batch["t"], batch["x_t"], batch.get("cond"), rngs=keys
    )
    del model_output
    return loss_fn(prediction, batch["v_t"]).mean()

  loss, grads = nnx.value_and_grad(mean_loss)(model, rngs)

  # newer flax versions expect the model as the first argument of `update`,
  # flax version < 0.11.0 only takes the gradients
  accepts_model = "model" in inspect.signature(optimizer.update).parameters
  update_args = (model, grads) if accepts_model else (grads,)
  optimizer.update(*update_args)

  metrics = {"loss": loss, "grad_norm": optax.global_norm(grads)}

  if model_callback_fn is not None:
    model_callback_fn(model)
  return metrics
[docs]
def interpolate_samples(
    rng: jax.Array,
    x0: jax.Array,
    x1: jax.Array,
    cond: Optional[jax.Array] = None,
    *,
    time_sampler: Optional[Callable[[jax.Array, Tuple[int], jnp.dtype],
                                    jax.Array]] = None
) -> Batch:
  """Draw one time per sample and linearly interpolate between ``x0``, ``x1``.

  Args:
    rng: Random number generator passed to the time sampler.
    x0: Source samples, array of shape ``[batch, ...]``.
    x1: Target samples, array of shape ``[batch, ...]``.
    cond: Optional condition, stored in the output unchanged.
    time_sampler: Function with a signature ``(rng, shape, dtype) -> time``.
      It must return an array of shape ``[batch,]``. If :obj:`None`,
      :func:`jax.random.uniform` is used.

  Returns:
    Dictionary with the following entries:

    - ``'t'`` - sampled time, array of shape ``[batch,]``.
    - ``'x_t'`` - interpolant :math:`(1 - t) x_0 + t x_1`.
    - ``'v_t'`` - target velocity :math:`x_1 - x_0`.
    - ``'cond'`` - the condition, only present if ``cond`` is not
      :obj:`None`.
  """
  sampler = jr.uniform if time_sampler is None else time_sampler

  expected_shape = (len(x0),)
  t = sampler(rng, expected_shape, x0.dtype)
  assert t.shape == expected_shape, (t.shape, expected_shape)

  # append singleton axes so that the time broadcasts against the samples
  t_bcast = jnp.expand_dims(t, axis=tuple(range(1, x0.ndim)))
  x_t = (1.0 - t_bcast) * x0 + t_bcast * x1

  out = {"t": t, "x_t": x_t, "v_t": x1 - x0}
  if cond is not None:
    out["cond"] = cond
  return out
[docs]
def evaluate_velocity_field(
    model: nnx.Module,
    x: Union[jax.Array, Any],
    cond: Optional[jax.Array] = None,
    *,
    t0: float = 0.0,
    t1: float = 1.0,
    reverse: bool = False,
    num_steps: Optional[int] = None,
    solver: Optional[diffrax.AbstractSolver] = None,
    save_trajectory_kwargs: Optional[Dict[str, Any]] = None,
    save_velocity_kwargs: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> diffrax.Solution:
  """Solve an ODE.

  Args:
    model: Velocity field with a signature ``(t, x_t, cond) -> v_t``.
    x: Initial point of shape ``[*dims]``.
    cond: Condition of shape ``[*cond_dims]``.
    t0: Start time of the integration.
    t1: End time of the integration.
    reverse: Whether to integrate from :math:`t_1` to :math:`t_0`.
    num_steps: Number of steps used for solvers with a constant step size.
      The step size is ``(t1 - t0) / num_steps`` (negated when integrating
      in reverse) and ``max_steps`` is set to ``num_steps``.
    solver: ODE solver. If :obj:`None` and ``num_steps = None``,
      use :class:`~diffrax.Dopri5`. If :obj:`None` and ``num_steps`` is an
      :class:`int`, use :class:`~diffrax.Euler`.
    save_trajectory_kwargs: Keyword arguments for :class:`~diffrax.SubSaveAt`
      used to store the positions along the integration path.
      The trajectory will be saved in :class:`out.ys['x_t'] <diffrax.Solution>`.
    save_velocity_kwargs: Keyword arguments for :class:`~diffrax.SubSaveAt`
      used to store the velocities along the integration path.
      The velocity will be saved in :class:`out.ys['v_t'] <diffrax.Solution>`.
    kwargs: Keyword arguments for :func:`~diffrax.diffeqsolve`.

  Returns:
    The ODE solution.
  """
  if reverse:
    t0, t1 = t1, t0

  if isinstance(num_steps, int):
    # derive the signed step from the actual integration interval, so that
    # exactly `num_steps` constant steps span it; a hard-coded `1 / num_steps`
    # is only correct for the unit interval and would otherwise exhaust
    # `max_steps` (longer interval) or take fewer steps (shorter interval)
    step_size = (t1 - t0) / num_steps
    stepsize_controller = diffrax.ConstantStepSize()
    solver = diffrax.Euler() if solver is None else solver
    kwargs["max_steps"] = num_steps
  else:
    # adaptive stepping: let the solver pick the initial step size
    step_size = None
    stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
    solver = diffrax.Dopri5() if solver is None else solver

  default_velocity_fn = jtu.Partial(_velocity, model=model)
  # internally, we allow for passing custom velocity functions:
  # this is used when computing the gaussian NLL, as we need to
  # both integrate the state and the divergence of the velocity field
  velocity_fn = kwargs.pop("_velocity_fn", default_velocity_fn)

  subs = {}
  if save_velocity_kwargs:
    subs["v_t"] = diffrax.SubSaveAt(
        fn=default_velocity_fn, **save_velocity_kwargs
    )
  if save_trajectory_kwargs:
    subs["x_t"] = diffrax.SubSaveAt(
        fn=lambda _, x_t, __: x_t, **save_trajectory_kwargs
    )
  if subs:
    # replaces any `saveat` passed through `kwargs`
    kwargs["saveat"] = diffrax.SaveAt(subs=subs)

  return diffrax.diffeqsolve(
      diffrax.ODETerm(velocity_fn),
      t0=t0,
      t1=t1,
      y0=x,
      args=cond,
      solver=solver,
      dt0=step_size,
      stepsize_controller=stepsize_controller,
      **kwargs,
  )
[docs]
def curvature(
    model: nnx.Module,
    x0: jax.Array,
    cond: Optional[jax.Array] = None,
    *,
    ts: Union[int, jax.Array, Sequence[float]],
    drop_last_velocity: Optional[bool] = None,
    loss_fn: Callable[[jax.Array, jax.Array], jax.Array] = optax.squared_error,
    **kwargs: Any,
) -> Tuple[jax.Array, diffrax.Solution]:
  """Compute the curvature :cite:`lee:23`.

  Also known as straightness in :cite:`liu:22`. The velocities stored along
  the integration path are compared, using ``loss_fn``, against the
  displacement between the final and the initial point.

  Args:
    model: Velocity field with a signature ``(t, x_t, cond) -> v_t``.
    x0: Initial point of shape ``[*dims]``.
    cond: Condition of shape ``[*cond_dims]``.
    ts: Time points at which the velocities are stored. If :class:`int`,
      use ``ts`` linearly-spaced points in ``[t0, t1]``, where ``t0`` and
      ``t1`` are read from ``kwargs`` and default to ``0.0`` and ``1.0``.
    drop_last_velocity: Whether to exclude the velocity at ``ts[-1]``
      from the computation. If :obj:`None`, exclude it when
      ``ts[-1] == 1.0``.
    loss_fn: Loss function with a signature ``(pred, target) -> loss``.
    kwargs: Keyword arguments for :func:`evaluate_velocity_field`.

  Returns:
    The curvature and the ODE solution.
  """
  if isinstance(ts, int):
    assert ts > 0, f"Number of steps must be positive, got {ts}."
    ts = np.linspace(kwargs.get("t0", 0.0), kwargs.get("t1", 1.0), ts)

  if drop_last_velocity is None:
    drop_last_velocity = ts[-1] == 1.0

  sol = evaluate_velocity_field(
      model,
      x0,
      cond,
      reverse=False,
      save_trajectory_kwargs={"t1": True},  # position only at the end time
      save_velocity_kwargs={"ts": ts},  # velocities at the requested times
      **kwargs,
  )

  velocities = sol.ys["v_t"]
  if drop_last_velocity:
    velocities = velocities[:-1]
  x1 = sol.ys["x_t"][-1]

  num_velocities = len(ts) - drop_last_velocity
  expected_shape = (num_velocities, *x0.shape)
  assert x0.shape == x1.shape, (x0.shape, x1.shape)
  assert velocities.shape == expected_shape, (velocities.shape, expected_shape)

  # every stored velocity is compared against the same displacement
  displacement = x1 - x0
  per_time_loss = jax.vmap(loss_fn, in_axes=(0, None))(velocities, displacement)
  return per_time_loss.mean(), sol
[docs]
def gaussian_nll(
    model: nnx.Module,
    x1: jax.Array,
    cond: Optional[jax.Array] = None,
    *,
    noise: Optional[jax.Array] = None,
    stddev: float = 1.0,
    **kwargs: Any,
) -> Tuple[jax.Array, diffrax.Solution]:
  """Compute the Gaussian negative log-likelihood.

  The point and the divergence of the velocity field are integrated jointly
  in reverse; the integrated divergence is then added to the log-density of
  the end point under an isotropic Gaussian with standard deviation
  ``stddev``.

  Args:
    model: Velocity model with a signature ``(t, x_t, cond) -> v_t``.
    x1: Initial point of the reverse integration, of shape ``[*dims]``.
    cond: Condition ``[*cond_dims]``.
    noise: Array of shape ``[num_noise_samples, *dims]`` used for the
      Hutchinson's trace estimate of the divergence of the velocity field.
      If :obj:`None`, compute the exact divergence using :func:`jax.jacrev`.
    stddev: Standard deviation of the Gaussian distribution.
    kwargs: Keyword arguments for :func:`evaluate_velocity_field`.

  Returns:
    The negative log-likelihood and the ODE solution. The value is computed
    with natural logarithms and summed over all dimensions; it is not
    normalized to bits-per-dimension.
  """
  if noise is None:
    divergence_fn = _exact_divergence
  else:
    _, *noise_shape = noise.shape  # [batch, ...]
    assert x1.shape == tuple(noise_shape), (x1.shape, noise_shape)
    divergence_fn = functools.partial(_hutchinson_divergence, h=noise)

  augmented_state = (x1, jnp.zeros([]))  # initial point, divergence
  sol = evaluate_velocity_field(
      model,
      augmented_state,
      cond,
      reverse=True,
      saveat=diffrax.SaveAt(t1=True),
      save_trajectory_kwargs=None,
      save_velocity_kwargs=None,
      _velocity_fn=jtu.Partial(divergence_fn, model=model),
      **kwargs,
  )

  # only the final time is saved, hence the leading axis of size 1
  x0, neg_int01_div_v = sol.ys
  assert x0.shape == (1, *x1.shape), (x0.shape, (1, *x1.shape))
  assert neg_int01_div_v.shape == (1,), neg_int01_div_v.shape

  # log-density of `x0` under the isotropic Gaussian
  k = np.prod(x0.shape)
  quadratic = -0.5 * ((x0 / stddev) ** 2).sum()
  log_two_pi_term = 0.5 * k * jnp.log(2.0 * jnp.pi)
  log_scale_term = k * jnp.log(stddev)
  logp0_x0 = quadratic - log_two_pi_term - log_scale_term

  logp1_x1 = logp0_x0 + neg_int01_div_v[0]
  return -logp1_x1, sol
def _velocity(
    t: jax.Array, x_t: jax.Array, cond: Optional[jax.Array], model: nnx.Module
) -> jax.Array:
  """Evaluate ``model`` on a single, unbatched sample.

  A leading axis of size 1 is added to the time, the position and, if
  present, the condition before calling the model; it is removed again from
  the model's output.

  Args:
    t: Time.
    x_t: Position.
    cond: Condition, :obj:`None` is passed to the model unchanged.
    model: Velocity field called as ``model(t, x_t, cond)``.

  Returns:
    The model output with the leading axis squeezed out.
  """
  batched_cond = cond if cond is None else cond[None]
  batched_velocity = model(t[None], x_t[None], batched_cond)
  return batched_velocity.squeeze(0)
def _exact_divergence(
    t: jax.Array, state_t: DivState, cond: Optional[jax.Array], *,
    model: nnx.Module
) -> DivState:
  """Return the velocity and its exact divergence at the current state.

  The divergence is the trace of the Jacobian of the velocity with respect
  to the position, obtained with :func:`jax.jacrev`.

  Args:
    t: Time.
    state_t: Pair whose first entry is the position; the second entry
      is ignored.
    cond: Condition.
    model: Velocity field.

  Returns:
    The velocity and the divergence of the velocity field.
  """
  x_t, _ = state_t
  v_t = _velocity(t, x_t, cond, model=model)

  # divergence of fwd velocity field: flatten the Jacobian to a square
  # matrix and take its trace
  jacobian = jax.jacrev(_velocity, argnums=1)(t, x_t, cond, model)
  flat_jacobian = jacobian.reshape(np.prod(x_t.shape), np.prod(x_t.shape))
  div_t = jnp.trace(flat_jacobian)

  return v_t, div_t
def _hutchinson_divergence(
    t: jax.Array, state_t: DivState, cond: Optional[jax.Array], *,
    model: nnx.Module, h: jax.Array
) -> DivState:
  """Return the velocity and a Hutchinson estimate of its divergence.

  For every noise sample in ``h``, the vector-Jacobian product of the
  velocity with that sample is contracted with the sample itself; the
  estimate is the mean of these values over the noise samples.

  Args:
    t: Time.
    state_t: Pair whose first entry is the position; the second entry
      is ignored.
    cond: Condition.
    model: Velocity field.
    h: Noise samples, stacked along the leading axis.

  Returns:
    The velocity and the estimated divergence of the velocity field.
  """
  x_t, _ = state_t

  def velocity_at(x: jax.Array) -> jax.Array:
    return _velocity(t, x, cond, model=model)

  v_t, pullback = jax.vjp(velocity_at, x_t)
  (vjp_h,) = jax.vmap(pullback)(h)
  per_sample = jax.vmap(jnp.vdot, in_axes=(0, 0))(h, vjp_h)
  return v_t, per_sample.mean()