ott.neural.networks

Networks

icnn.ICNN(dim_data, dim_hidden[, ranks, ...])

Input convex neural network (ICNN).
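An ICNN is a network whose scalar output is convex in its input. The minimal JAX sketch below shows the standard construction. It is not OTT's `ICNN`, and the parameter layout, initialization, and softplus activation are assumptions. Weights from one hidden layer to the next are kept non-negative. The activation is convex and non-decreasing. Weights applied directly to the input are left unconstrained.

```python
import jax
import jax.numpy as jnp


def init_icnn(key, dim_data, dim_hidden):
    """Random parameters for a toy ICNN (hypothetical layout, not OTT's)."""
    dims = list(dim_hidden) + [1]  # hidden widths, then a scalar output
    keys = jax.random.split(key, 2 * len(dims))
    params = {"Wx": [], "Wz": [], "b": []}
    for i, d in enumerate(dims):
        # input-to-layer weights: unconstrained
        params["Wx"].append(0.5 * jax.random.normal(keys[2 * i], (dim_data, d)))
        params["b"].append(jnp.zeros(d))
        if i > 0:
            # layer-to-layer weights: made non-negative in the forward pass
            params["Wz"].append(jax.random.normal(keys[2 * i + 1], (dims[i - 1], d)))
    return params


def icnn_forward(params, x):
    """Scalar output that is convex in x; x has shape (batch, dim_data)."""
    n_layers = len(params["Wx"])
    # first layer: softplus of an affine map is convex in x
    z = jax.nn.softplus(x @ params["Wx"][0] + params["b"][0])
    for i in range(1, n_layers):
        wz = jax.nn.softplus(params["Wz"][i - 1])  # non-negative weights
        pre = z @ wz + x @ params["Wx"][i] + params["b"][i]
        # softplus is convex and non-decreasing, so convexity is preserved;
        # the last layer is left affine in z
        z = pre if i == n_layers - 1 else jax.nn.softplus(pre)
    return z.squeeze(-1)
```

Convexity depends on only two constraints: the non-negative layer-to-layer weights and a convex, non-decreasing activation. The input-to-layer weights can take any sign.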

velocity_field.VelocityField(hidden_dims, ...)

Neural vector field.
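A neural vector field is a network \(v_\theta(t, x)\) that returns a velocity with the same dimension as \(x\). Integrating it over time transports points. The sketch below is illustrative only and does not reproduce OTT's `VelocityField`. It conditions on time by appending \(t\) as an extra input feature, and it integrates with explicit Euler steps from \(t=0\) to \(t=1\). The helper names, the SiLU activation, and the step count are all hypothetical choices.

```python
import jax
import jax.numpy as jnp


def init_mlp(key, sizes):
    """Hypothetical MLP parameters: a list of (weight, bias) pairs."""
    keys = jax.random.split(key, len(sizes) - 1)
    return [(jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n))
            for k, m, n in zip(keys, sizes[:-1], sizes[1:])]


def velocity(params, t, x):
    """v(t, x) for x of shape (batch, dim); output has the same shape."""
    # condition on time by appending t as one extra feature (an assumption)
    h = jnp.concatenate([x, jnp.full(x.shape[:-1] + (1,), t)], axis=-1)
    for w, b in params[:-1]:
        h = jax.nn.silu(h @ w + b)
    w, b = params[-1]
    return h @ w + b


def euler_flow(params, x0, n_steps=10):
    """Push x0 from t=0 to t=1 along the field with explicit Euler steps."""
    dt = 1.0 / n_steps
    x = x0
    for i in range(n_steps):
        x = x + dt * velocity(params, i * dt, x)
    return x
```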

potentials.BasePotential([parent, name])

Base class for the neural solver models.

potentials.PotentialMLP(dim_hidden[, ...])

Potential MLP.

potentials.PotentialTrainState(step, ...)

Adds information about the model's value and gradient to the state.
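A potential model is used through two functions: its scalar value \(f(x)\) and its gradient \(\nabla f(x)\). In neural OT solvers, the gradient of a learned potential commonly serves as a transport map. The short JAX sketch below derives both functions from a toy potential with `jax.grad` and batches them with `jax.vmap`. The potential and all names are hypothetical, and the code does not show the fields that `PotentialTrainState` actually stores.

```python
import jax
import jax.numpy as jnp


def toy_potential(params, x):
    """Hypothetical scalar potential f(x) = 0.5 * a * ||x||^2 + b^T x."""
    a, b = params
    return 0.5 * a * jnp.sum(x ** 2) + jnp.dot(b, x)


params = (2.0, jnp.array([1.0, -1.0]))

# value and gradient of the potential for a single point ...
value_fn = lambda x: toy_potential(params, x)
grad_fn = jax.grad(value_fn)

# ... and vectorized over a batch of points with shape (batch, dim)
batched_value = jax.vmap(value_fn)
batched_grad = jax.vmap(grad_fn)
```

For this potential, \(\nabla f(x) = a x + b\). At \(x = (1, 2)\), the value is \(4\) and the gradient is \((3, 3)\).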

ott.neural.networks.layers

Layers

conjugate.FenchelConjugateSolver()

Abstract conjugate solver class.

conjugate.FenchelConjugateLBFGS([gtol, ...])

Solve for the conjugate using LBFGS.

conjugate.ConjugateResults(val, grad, num_iter)

Holds the results of numerically conjugating a function.
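Numerical conjugation computes \(f^*(y) = \sup_x \langle x, y\rangle - f(x)\) by minimizing \(f(x) - \langle x, y\rangle\). The sketch below is not OTT's solver. It swaps L-BFGS for the BFGS method of `jax.scipy.optimize.minimize`, and it returns a hypothetical `ToyConjugateResults` tuple whose fields mirror the `val`, `grad`, `num_iter` names listed above. For a smooth, strictly convex \(f\), the maximizer \(x^*\) equals \(\nabla f^*(y)\), which is what `grad` holds here.

```python
from typing import NamedTuple

import jax.numpy as jnp
from jax.scipy.optimize import minimize


class ToyConjugateResults(NamedTuple):
    val: jnp.ndarray       # f*(y)
    grad: jnp.ndarray      # maximizer x*, equal to grad f*(y) for smooth strictly convex f
    num_iter: jnp.ndarray  # optimizer iterations


def toy_conjugate(f, y, x_init):
    """Estimate f*(y) = sup_x <x, y> - f(x) with BFGS (stand-in for L-BFGS)."""
    # minimize the negated objective f(x) - <x, y>
    objective = lambda x, y: f(x) - jnp.dot(x, y)
    res = minimize(objective, x_init, args=(y,), method="BFGS", tol=1e-8)
    return ToyConjugateResults(val=-res.fun, grad=res.x, num_iter=res.nit)
```

As a check, \(f(x) = \tfrac12 \|x\|^2\) is its own conjugate, so \(y = (1, -2)\) should give \(f^*(y) = 2.5\) with maximizer \(y\) itself.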

posdef.PositiveDense(dim_hidden[, ...])

A linear transformation using a matrix with all entries non-negative.
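One way to keep every matrix entry non-negative is to store unconstrained parameters and pass them through a non-negative function before use. The sketch below uses softplus for this. That choice and the function name are assumptions and may differ from how `PositiveDense` is implemented.

```python
import jax
import jax.numpy as jnp


def toy_positive_dense(raw_kernel, bias, x):
    """Affine map x @ W + bias whose matrix W has only non-negative entries.

    Non-negativity is enforced by passing unconstrained parameters
    through softplus (an assumed choice for this sketch).
    """
    kernel = jax.nn.softplus(raw_kernel)
    return x @ kernel + bias
```

With non-negative entries, the map is non-decreasing in each input coordinate. Applied to convex features, it therefore yields convex outputs, because a non-negative combination of convex functions is convex. This is the property an ICNN needs for its layer-to-layer weights.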

posdef.PosDefPotentials(num_potentials[, ...])

\(\frac{1}{2} x^T (A_i A_i^T + \text{Diag}(d_i)) x + b_i^T x + c_i\)
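The sketch below evaluates a family of quadratic potentials \(\frac{1}{2} x^T (A_i A_i^T + \text{Diag}(d_i)) x + b_i^T x + c_i\), one for each \(i\), over a batch of inputs. It treats the linear term as \(b_i^T x\). It uses the low-rank factor \(A_i\) directly, so the full matrix is never formed. Shapes and names are assumptions, and not OTT's parameterization. The test keeps \(d_i \ge 0\) so that each matrix is positive semi-definite.

```python
import jax
import jax.numpy as jnp


def toy_posdef_potentials(x, A, d, b, c):
    """0.5 * x^T (A_i A_i^T + Diag(d_i)) x + b_i^T x + c_i for every i.

    x: (batch, dim); A: (k, dim, rank); d, b: (k, dim); c: (k,)
    returns: (batch, k)
    """
    low_rank = jnp.einsum("nd,kdr->nkr", x, A)       # A_i^T x
    quad = jnp.sum(low_rank ** 2, axis=-1)           # x^T A_i A_i^T x
    quad = quad + jnp.einsum("nd,kd->nk", x ** 2, d)  # x^T Diag(d_i) x
    return 0.5 * quad + x @ b.T + c
```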

time_encoder.cyclical_time_encoder(t[, n_freqs])

Encode time \(t\) into a cyclical representation.
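A cyclical encoding maps a scalar time to the cosines and sines of several multiples of it. This gives a network smooth, bounded time features. The sketch below is a hypothetical stand-in, not `cyclical_time_encoder` itself. It assumes frequencies \(2\pi k\) for \(k = 0, \dots, n-1\) and inputs of shape `(batch, 1)`.

```python
import jax.numpy as jnp


def toy_cyclical_encoding(t, n_freqs):
    """Map times t of shape (batch, 1) to (batch, 2 * n_freqs) features.

    Frequencies 2*pi*k, k = 0..n_freqs-1, are an assumption of this sketch.
    """
    freqs = 2.0 * jnp.pi * jnp.arange(n_freqs)
    angles = t * freqs  # broadcasts to (batch, n_freqs)
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)
```

Every frequency in this sketch is an integer multiple of \(2\pi\), so the encoding has period 1 in \(t\). Each cosine/sine pair also lies on the unit circle.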