ott.neural.networks.potentials.PotentialMLP.potential_value_fn

PotentialMLP.potential_value_fn(other_potential_value_fn=None)

Return a callable giving the potential value.

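For potential models (is_potential=True), this simply calls the model. For gradient models (is_potential=False), the network outputs \(\nabla_y g(y)\) rather than \(g(y)\), so the value is reconstructed from the other potential's value function \(f\) via the envelope theorem:

\[g(y) = -f(\nabla_y g(y)) + y^T \nabla_y g(y)\]

This is the conjugacy relation \(g(y) = \sup_x \{ y^T x - f(x) \}\) evaluated at its maximizer, which by the envelope theorem is \(x = \nabla_y g(y)\).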
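The reconstruction can be checked on a case with a closed form. This is a minimal sketch, not OTT code: it assumes a quadratic potential \(f(x) = \tfrac12 x^T A x\) with \(A\) symmetric positive definite, whose conjugate is \(g(y) = \tfrac12 y^T A^{-1} y\) with \(\nabla_y g(y) = A^{-1} y\). The names `f`, `grad_g`, `g_exact` and `g_reconstructed` are hypothetical stand-ins. The gradient output is wrapped in `jax.lax.stop_gradient`, one way to make autodiff follow the envelope theorem: differentiating the reconstructed value then returns the gradient model's output.

```python
import jax
import jax.numpy as jnp

# Hypothetical quadratic potential f(x) = 0.5 * x^T A x with A symmetric positive definite.
A = jnp.array([[2.0, 0.5], [0.5, 1.0]])
A_inv = jnp.linalg.inv(A)


def f(x):
    # Value function of the other potential (the role of other_potential_value_fn).
    return 0.5 * x @ A @ x


def grad_g(y):
    # Stand-in for a gradient model; for this f the exact answer is grad g(y) = A^{-1} y.
    return A_inv @ y


def g_exact(y):
    # Closed-form conjugate g(y) = f^*(y) = 0.5 * y^T A^{-1} y, used only for checking.
    return 0.5 * y @ A_inv @ y


def g_reconstructed(y):
    # Detach the gradient output, then apply g(y) = -f(grad g(y)) + y^T grad g(y).
    x = jax.lax.stop_gradient(grad_g(y))
    return -f(x) + y @ x


y = jnp.array([1.0, -2.0])
print(g_reconstructed(y), g_exact(y))            # agree up to float32 rounding
print(jax.grad(g_reconstructed)(y), grad_g(y))  # gradient equals the model output
```

Because `x` is detached, `jax.grad(g_reconstructed)` sees only the linear term \(y^T x\), so its gradient is exactly \(x\).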

Parameters:

other_potential_value_fn (Optional[Callable[[Array], Array]]) – value function \(f\) of the other potential. Required when is_potential=False.

Return type:

Callable[[Array], Array]

Returns:

A callable mapping an input x to its scalar potential value (or to an array of values when x is batched).
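Putting the two cases together, the dispatch and the returned callable might look like the following. This is a minimal sketch, not OTT's implementation: `make_potential_value_fn` and `model_fn` are hypothetical names, raising `ValueError` for a missing `other_potential_value_fn` is an assumption, and batching is done here with `jax.vmap` over a per-example callable, which is one way to obtain batched values.

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp

Array = jnp.ndarray  # alias for readability


def make_potential_value_fn(
    model_fn: Callable[[Array], Array],
    is_potential: bool,
    other_potential_value_fn: Optional[Callable[[Array], Array]] = None,
) -> Callable[[Array], Array]:
    """Hypothetical mirror of the dispatch described above (not OTT code)."""
    if is_potential:
        # The model already outputs the scalar potential value.
        return model_fn
    if other_potential_value_fn is None:
        raise ValueError("other_potential_value_fn is required when is_potential=False")

    def value_fn(y: Array) -> Array:
        # The model outputs grad g(y); detach it and apply the envelope-theorem formula.
        grad_g_y = jax.lax.stop_gradient(model_fn(y))
        return -other_potential_value_fn(grad_g_y) + jnp.dot(y, grad_g_y)

    return value_fn


def f(x: Array) -> Array:
    # f(x) = 0.5 * ||x||^2 is its own conjugate, so grad g(y) = y and g(y) = 0.5 * ||y||^2.
    return 0.5 * jnp.sum(x ** 2)


# Gradient-model case: the identity map plays the role of the network's gradient output.
value_fn = make_potential_value_fn(lambda y: y, is_potential=False, other_potential_value_fn=f)
# Potential-model case: the model is called directly.
direct_fn = make_potential_value_fn(f, is_potential=True)

ys = jnp.array([[1.0, 2.0], [3.0, -1.0]])
values = jax.vmap(value_fn)(ys)  # one scalar per row
print(values, jax.vmap(direct_fn)(ys))
```

Both branches return a per-example callable with the same signature, so downstream code can batch or differentiate it without knowing which kind of model produced it.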