ott.neural.networks
Networks

Input convex neural network (ICNN).
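Here is a minimal NumPy sketch of the usual ICNN construction, to show what makes a network input convex. The hidden-to-hidden weights are clamped to be non-negative, and the activations are convex and nondecreasing, so the scalar output is convex in the input. This is not OTT's ICNN class. The names `init_icnn` and `icnn`, the layer sizes, and clamping with `np.maximum` are all assumptions made for illustration.

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)): convex and nondecreasing.
    return np.logaddexp(0.0, x)

def init_icnn(seed, dim, hidden=(16, 16)):
    # Hypothetical initializer: Wx maps the input into every layer,
    # Wz maps layer k-1 into layer k (absent for the first layer).
    rng = np.random.default_rng(seed)
    params = {"Wx": [], "Wz": [], "b": []}
    prev = None
    for h in list(hidden) + [1]:
        params["Wx"].append(rng.normal(size=(dim, h)) / np.sqrt(dim))
        params["b"].append(np.zeros(h))
        if prev is not None:
            params["Wz"].append(rng.normal(size=(prev, h)) / np.sqrt(prev))
        prev = h
    return params

def icnn(params, x):
    # First layer: softplus of an affine map is convex in x.
    z = softplus(x @ params["Wx"][0] + params["b"][0])
    n = len(params["Wx"])
    for k in range(1, n):
        # Clamping the z-path weights to be >= 0 keeps a non-negative
        # combination of convex functions, which stays convex.
        wz = np.maximum(params["Wz"][k - 1], 0.0)
        pre = z @ wz + x @ params["Wx"][k] + params["b"][k]
        # Hidden layers use a convex nondecreasing activation; the output layer stays affine.
        z = pre if k == n - 1 else softplus(pre)
    return z[..., 0]  # scalar output per input point
```

The clamp is one of several ways to enforce the sign constraint. A reparametrization such as softplus on the raw weights is another common choice.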

Vector-output network with ICNN-like architecture.

Base class for the neural solver models (Linen).

Potential MLP (NNX).

A simple MLP (NNX).

Adds information about the model's value and gradient to the state.
ott.neural.networks.velocity_field
MLP

MLP velocity field.
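In this context, a velocity field is a time-dependent map v(t, x). Integrating the ODE dx/dt = v(t, x) moves samples from t = 0 to t = 1. The sketch below uses NumPy instead of Flax. It assumes time is passed to a two-layer MLP by concatenating t to x, and it integrates with explicit Euler steps. The names `init_mlp`, `velocity` and `euler_push` are hypothetical, and OTT's MLP may condition on time differently.

```python
import numpy as np

def init_mlp(seed, dim, hidden=32):
    # Hypothetical two-layer MLP taking (t, x) of size dim + 1 and returning a vector of size dim.
    rng = np.random.default_rng(seed)
    return {
        "W1": rng.normal(size=(dim + 1, hidden)) / np.sqrt(dim + 1),
        "b1": np.zeros(hidden),
        "W2": rng.normal(size=(hidden, dim)) / np.sqrt(hidden),
        "b2": np.zeros(dim),
    }

def velocity(params, t, x):
    # Assumed time conditioning: append t as an extra input feature.
    tx = np.concatenate([np.full(x.shape[:-1] + (1,), t), x], axis=-1)
    h = np.tanh(tx @ params["W1"] + params["b1"])
    return h @ params["W2"] + params["b2"]

def euler_push(params, x0, n_steps=100):
    # Integrate dx/dt = v(t, x) from t = 0 to t = 1 with explicit Euler steps.
    x = x0
    dt = 1.0 / n_steps
    for i in range(n_steps):
        x = x + dt * velocity(params, i * dt, x)
    return x
```

If the output layer is all zeros, the field vanishes and `euler_push` returns its input unchanged. This is a quick sanity check for any integrator.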
UNet

UNet model with attention and timestep embedding.
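Diffusion-style UNets often use a sinusoidal timestep embedding, similar to Transformer positional encodings. The minimal sketch below assumes that scheme. The `max_period` default of 10000, the sine-then-cosine layout, and the helper name `timestep_embedding` are assumptions, not OTT's implementation.

```python
import numpy as np

def timestep_embedding(t, dim, max_period=10000.0):
    # Sinusoidal embedding with geometrically spaced frequencies:
    # first half sin(t * f_i), second half cos(t * f_i).
    t = np.asarray(t, dtype=float)
    half = dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)
    args = t[..., None] * freqs
    emb = np.concatenate([np.sin(args), np.cos(args)], axis=-1)
    if dim % 2:
        # Zero-pad odd dimensions so the output has exactly `dim` entries.
        emb = np.concatenate([emb, np.zeros(emb.shape[:-1] + (1,))], axis=-1)
    return emb
```

At t = 0, the sine half is all zeros and the cosine half is all ones. All entries lie in [-1, 1], so the embedding stays on the same scale whatever the value of t.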
EMA

Exponential moving average (EMA) of a model.

Create initial exponential moving average (EMA) state.

Update the EMA of a model.
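The three EMA entries amount to keeping a second copy of the parameters that trails the trained ones. Here is a minimal sketch, assuming parameters stored as a flat dict and the update `ema = decay * ema + (1 - decay) * params`. The names `ema_init` and `ema_update` and the `decay` default are hypothetical, not OTT's functions.

```python
def ema_init(params):
    # The EMA state starts as a copy of the current parameters.
    # A shallow copy suffices because ema_update builds new values, never mutating in place.
    return dict(params)

def ema_update(ema, params, decay=0.999):
    # Per-parameter update: ema <- decay * ema + (1 - decay) * params.
    return {k: decay * ema[k] + (1.0 - decay) * params[k] for k in ema}
```

For example, start from w = 0 and update twice towards w = 1 with decay 0.9. The EMA becomes 0.1 and then 0.19. The same code works unchanged when the values are NumPy arrays.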
ott.neural.networks.layers
Layers

Abstract conjugate solver class.

Solve for the conjugate using

Holds the results of numerically conjugating a function.
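The conjugate entries refer to the Legendre-Fenchel conjugate f*(y) = sup_x <x, y> - f(x). When f is a learned potential, this has to be computed numerically. For illustration, the sketch below swaps in plain gradient ascent on x and does not reproduce OTT's solver. It assumes f is smooth and strongly convex, so the maximizer x* is unique and equals the gradient of f* at y. The name `fenchel_conjugate_ascent`, its step size, and the `val`/`grad`/`num_iter` result fields are hypothetical.

```python
import numpy as np

def fenchel_conjugate_ascent(f, grad_f, y, x0, lr=0.1, n_iter=500, tol=1e-8):
    # Approximate f*(y) = sup_x <x, y> - f(x) by gradient ascent on x.
    x = np.array(x0, dtype=float)
    num_iter = 0
    for _ in range(n_iter):
        g = y - grad_f(x)  # gradient of x -> <x, y> - f(x)
        if np.linalg.norm(g) < tol:
            break  # first-order optimality: grad f(x) == y
        x = x + lr * g
        num_iter += 1
    # Hypothetical result record: conjugate value, maximizer (= grad f*(y)), iteration count.
    return {"val": float(x @ y - f(x)), "grad": x, "num_iter": num_iter}
```

Take f(x) = ||x||^2 / 2, whose conjugate is ||y||^2 / 2. At y = (1, -2), the call should return a value close to 2.5 and a maximizer close to y.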

A linear transformation with non-negative weights.
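Non-negative weights matter because a non-negative combination of convex functions is convex, and ICNN-style layers rely on this. One common way to enforce the constraint is to store unconstrained weights and pass them through softplus. The sketch below assumes that reparametrization, which may differ from what OTT does. `positive_dense` is a hypothetical helper.

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)); strictly positive.
    return np.logaddexp(0.0, x)

def positive_dense(raw_kernel, bias, x):
    # Hypothetical reparametrization: effective weights softplus(raw_kernel) > 0,
    # so every output is a nondecreasing affine function of every input.
    kernel = softplus(raw_kernel)
    return x @ kernel + bias
```

Because the effective weights are positive, feeding convex features such as `x**2` through this layer yields outputs that are still convex in x.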

Low-rank plus diagonal positive definite quadratic potentials.
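A quadratic potential of this kind can be written f(x) = 1/2 (x - m)^T A (x - m), with A = diag(d) + U U^T, d > 0 and U of low rank. A is then symmetric positive definite, and the gradient A (x - m) is an affine map. The sketch builds a single such potential. The softplus on the diagonal and the names `make_quadratic`, `raw_diag`, `U` and `mean` are assumptions, not OTT's parametrization.

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)); strictly positive.
    return np.logaddexp(0.0, x)

def make_quadratic(raw_diag, U, mean):
    # A = diag(d) + U U^T with d = softplus(raw_diag) > 0, hence symmetric positive definite.
    A = np.diag(softplus(raw_diag)) + U @ U.T

    def value(x):
        # f(x) = 0.5 (x - m)^T A (x - m)
        diff = x - mean
        return 0.5 * diff @ A @ diff

    def grad(x):
        # grad f(x) = A (x - m), using the symmetry of A
        return A @ (x - mean)

    return A, value, grad
```

The low-rank term stores only a d-by-r matrix instead of a full d-by-d one. The positive diagonal guarantees definiteness even when r is much smaller than d.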