ott.neural.networks

Networks

icnn.ICNN(dim_hidden, *, input_dim[, ...])

Input convex neural network (ICNN).
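The idea behind an input convex network can be shown with a toy forward pass. This is a minimal NumPy sketch, not the library's implementation. It assumes the standard ICNN recipe: hidden-to-hidden weights are kept non-negative (here by clipping), input "skip" weights are unconstrained, and the activation is convex and non-decreasing (softplus here). Under those rules each layer stays convex in `x`. The function names and weight layout are hypothetical.

```python
import numpy as np

def softplus(z):
    # Smooth, convex, non-decreasing activation.
    return np.logaddexp(0.0, z)

def icnn_forward(x, W_x, W_z, b):
    """Scalar output of a toy ICNN (hypothetical helper).

    W_x: input ("skip") weight matrices, one per layer, unconstrained.
    W_z: hidden-to-hidden weight matrices; clipped to be non-negative,
         which is what keeps the output convex in x.
    b:   bias vectors, one per layer.
    """
    z = softplus(W_x[0] @ x + b[0])
    for Wx, Wz, bk in zip(W_x[1:], W_z, b[1:]):
        # Non-negative combination of convex z, plus an affine term in x,
        # passed through a convex non-decreasing activation: still convex.
        z = softplus(np.maximum(Wz, 0.0) @ z + Wx @ x + bk)
    return float(z.sum())

# Random toy weights for a 2 -> 8 -> 8 -> 1 network.
rng = np.random.default_rng(0)
dims = [2, 8, 8, 1]
W_x = [rng.normal(size=(d, dims[0])) for d in dims[1:]]
W_z = [rng.normal(size=(dims[i + 1], dims[i])) for i in range(1, len(dims) - 1)]
b = [rng.normal(size=d) for d in dims[1:]]
```

Convexity can then be spot-checked with the midpoint inequality f((a + c) / 2) <= (f(a) + f(c)) / 2 on random pairs of inputs.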

icnn.KeyNet(dim_hidden, *, input_dim[, ...])

Vector-output network with ICNN-like architecture.

potentials.BasePotential([parent, name])

Base class for the neural solver models (Linen).

potentials.PotentialMLP(dim_hidden, *, input_dim)

Potential MLP (NNX).

potentials.MLP(dim_hidden, *, input_dim[, ...])

A simple MLP (NNX).

potentials.PotentialTrainState(step, ...)

Adds information about the model's value and gradient to the state.

ott.neural.networks.velocity_field

MLP

mlp.MLP(dim, *[, hidden_dims, cond_dim, ...])

MLP velocity field.
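A velocity field v(t, x) is typically used by integrating the ODE dx/dt = v(t, x) from t = 0 to t = 1. The sketch below does this with explicit Euler steps in NumPy. It assumes a hand-written, hypothetical velocity (the constant straight-line field from `x0` to `x1`) in place of a trained `mlp.MLP`, and `euler_integrate` is an illustrative helper, not part of the library.

```python
import numpy as np

def euler_integrate(velocity, x0, num_steps=100):
    """Integrate dx/dt = velocity(t, x) from t=0 to t=1 with explicit Euler."""
    x = np.asarray(x0, dtype=float)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt
        x = x + dt * velocity(t, x)
    return x

# Hypothetical velocity field: the straight-line path from x0 to x1,
# whose velocity is constant in time and space.
x0 = np.array([0.0, 0.0])
x1 = np.array([3.0, -1.0])

def straight(t, x):
    return x1 - x0

x_end = euler_integrate(straight, x0)
```

For a constant field Euler integration is exact up to rounding, so `x_end` lands on `x1`; a learned field would generally need more steps or a higher-order solver.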

UNet

unet.UNet(*, shape, model_channels, ...[, ...])

UNet model with attention and timestep embedding.
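A common way to feed a scalar timestep into a network such as a UNet is a sinusoidal embedding. The NumPy sketch below shows that technique under the assumption that the embedding is of this sin/cos form; whether `unet.UNet` uses exactly this scheme, and its `max_period`, are assumptions, and `sinusoidal_embedding` is a hypothetical helper.

```python
import numpy as np

def sinusoidal_embedding(t, dim, max_period=10_000.0):
    """Map scalar times t (shape [n]) to [n, dim] sin/cos features."""
    assert dim % 2 == 0, "dim must be even"
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to about 1 / max_period.
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = np.asarray(t, dtype=float)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = sinusoidal_embedding(np.array([0.0, 0.5, 1.0]), dim=8)
```

At t = 0 the sine half is all zeros and the cosine half is all ones; nearby times map to nearby but distinct feature vectors.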

EMA

ema.EMA(model, *, decay)

Exponential moving average (EMA) of a model.

ema.init_ema(model)

Create initial exponential moving average (EMA) state.

ema.update_ema(model, *, ema, decay)

Update the EMA of a model.

ott.neural.networks.layers

Layers

conjugate.FenchelConjugateSolver()

Abstract conjugate solver class.

conjugate.FenchelConjugateLBFGS([gtol, ...])

Solve for the conjugate using LBFGS.

conjugate.ConjugateResults(val, grad, num_iter)

Holds the results of numerically conjugating a function.
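Numerical conjugation evaluates f*(y) = sup_x <x, y> - f(x). For a strictly convex f the maximizer x* also equals the gradient of f* at y, which is why a results record can carry a value, a gradient, and an iteration count. The NumPy sketch below uses plain gradient ascent as a simpler swapped-in method, not LBFGS, and `ConjugateResultsSketch` and `conjugate_by_gradient_ascent` are hypothetical names that only mirror the fields listed for `conjugate.ConjugateResults`.

```python
from typing import NamedTuple

import numpy as np

class ConjugateResultsSketch(NamedTuple):
    val: float        # f*(y)
    grad: np.ndarray  # maximizer x*, equal to grad f*(y) for strictly convex f
    num_iter: int     # number of ascent steps taken

def conjugate_by_gradient_ascent(f, grad_f, y, x_init, lr=0.1, gtol=1e-8, max_iter=10_000):
    x = np.asarray(x_init, dtype=float)
    num_iter = 0
    while num_iter < max_iter:
        g = y - grad_f(x)  # gradient of <x, y> - f(x) with respect to x
        if np.linalg.norm(g) < gtol:
            break
        x = x + lr * g
        num_iter += 1
    return ConjugateResultsSketch(val=float(x @ y - f(x)), grad=x, num_iter=num_iter)

# Quadratic test function f(x) = 0.5 * x^T diag(a) x, whose conjugate is known
# in closed form: f*(y) = 0.5 * y^T diag(1/a) y with gradient y / a.
a = np.array([2.0, 4.0])
f = lambda x: 0.5 * np.sum(a * x**2)
grad_f = lambda x: a * x
y = np.array([1.0, 2.0])
res = conjugate_by_gradient_ascent(f, grad_f, y, x_init=np.zeros(2))
```

Here the closed form gives f*(y) = 0.75 and grad f*(y) = [0.5, 0.5], which the numerical result should match.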

posdef.PositiveDense(in_features, ...[, ...])

A linear transformation with non-negative weights.
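One way to keep a dense layer's weights non-negative is to store an unconstrained kernel and pass it through a non-negative map before use. The NumPy sketch below assumes softplus for that map; the transform `posdef.PositiveDense` actually applies is not stated here, and `positive_dense` is a hypothetical helper.

```python
import numpy as np

def softplus(z):
    return np.logaddexp(0.0, z)

def positive_dense(x, kernel, bias=None):
    """y = x @ softplus(kernel) + bias; the effective weights are always > 0."""
    w = softplus(kernel)  # unconstrained kernel -> strictly positive weights
    y = x @ w
    return y if bias is None else y + bias

rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 4))  # [in_features, out_features]
x = rng.normal(size=(5, 3))       # batch of 5 inputs
y = positive_dense(x, kernel)
```

Non-negative weights make the layer non-decreasing in every input, which is the property an ICNN needs on its hidden-to-hidden connections to preserve convexity.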

posdef.PosDefPotentials(in_features, ...[, ...])

Low-rank plus diagonal positive definite quadratic potentials.
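A low-rank plus diagonal quadratic can be written as 0.5 * (x - m)^T (L L^T + diag(d^2 + eps)) (x - m), which is positive definite for any factor L, any d, and eps > 0. The NumPy sketch below assumes this parametrization (mean `m`, factor `L` of shape [dim, rank], diagonal `d`, small `eps`); the exact form used by `posdef.PosDefPotentials` is an assumption, and `lowrank_plus_diag_quadratic` is a hypothetical helper.

```python
import numpy as np

def lowrank_plus_diag_quadratic(x, m, L, d, eps=1e-6):
    """0.5 * (x - m)^T (L L^T + diag(d**2 + eps)) (x - m) for x of shape [n, dim]."""
    diff = x - m                             # [n, dim]
    low_rank = diff @ L                      # [n, rank]
    diag_term = (d**2 + eps) * diff**2       # [n, dim]
    return 0.5 * (np.sum(low_rank**2, axis=-1) + np.sum(diag_term, axis=-1))

rng = np.random.default_rng(0)
dim, rank = 5, 2
L = rng.normal(size=(dim, rank))
d = rng.normal(size=dim)
m = rng.normal(size=dim)
x = rng.normal(size=(4, dim))
values = lowrank_plus_diag_quadratic(x, m, L, d)
```

Evaluating through `diff @ L` avoids materializing the dim x dim matrix, so the cost grows with dim * rank rather than dim squared.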