ott.solvers.nn.models.ModelBase.potential_value_fn

ModelBase.potential_value_fn(params, other_potential_value_fn=None)

Return a function giving the value of the potential.

Applies the module if is_potential is True; otherwise, constructs the value of the potential g from its gradient as

\[g(y) = -f(\nabla_y g(y)) + y^T \nabla_y g(y)\]

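where f is the other potential (given by other_potential_value_fn) and \(\nabla_y g(y)\) is detached, i.e., treated as a constant when differentiating, so that, by the envelope theorem [Bertsekas, 1971, Danskin, 1967], the construction has the appropriate first derivatives.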
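A short derivation of why this construction works, assuming (as in the standard convex-conjugate setting) that f is convex and differentiable, that g is its convex conjugate, and that the supremum below is attained at a unique maximizer \(x^\star(y)\):

\[g(y) = \sup_x \left\{ x^T y - f(x) \right\} = x^\star(y)^T y - f(x^\star(y))\]

By the envelope theorem, differentiating with respect to y while holding x fixed at the maximizer gives

\[\nabla_y g(y) = \left. \nabla_y \left( x^T y - f(x) \right) \right|_{x = x^\star(y)} = x^\star(y)\]

Substituting \(x^\star(y) = \nabla_y g(y)\) into the first line recovers the expression for g(y) given above. Because the dependence of \(x^\star(y)\) on y does not enter \(\nabla_y g(y)\), detaching \(\nabla_y g(y)\) makes the gradient of the constructed value equal to \(\nabla_y g(y)\) itself, which is exactly what the envelope theorem gives at the maximizer.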

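Parameters:

params – Parameters of the module.
other_potential_value_fn – Function giving the value of the other potential f. Only needed when is_potential is False.

Return type: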

Callable[[Array], Array]

Returns:

A function that can be evaluated to obtain the potential’s value
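As a rough illustration of the gradient-based branch, the construction can be sketched in plain JAX. This is not OTT's implementation: other_potential_value_fn, grad_g, and make_potential_value_fn below are hypothetical stand-ins, with f chosen as the quadratic \(f(x) = \frac{1}{2} x^T A x\), \(A = \mathrm{diag}(2, 4)\), so the numbers can be checked by hand. The stand-in gradient is deliberately inexact (scaled by 0.9), as a partially trained network might be.

```python
import jax
import jax.numpy as jnp

# Hypothetical "other" potential f(x) = 0.5 * x^T A x with A = diag(2, 4).
A_DIAG = jnp.array([2.0, 4.0])


def other_potential_value_fn(x: jnp.ndarray) -> jnp.ndarray:
    """Value of f(x) for a single point x of shape (2,)."""
    return 0.5 * jnp.sum(A_DIAG * x**2)


def grad_g(y: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical stand-in for a module that outputs the gradient of g.

    For this f, the exact conjugate gradient is A^{-1} y; the 0.9 factor
    makes it deliberately inexact.
    """
    return 0.9 * y / A_DIAG


def make_potential_value_fn(grad_fn, other_value_fn, detach=True):
    """Build g(y) = -f(x) + y^T x, where x = grad_fn(y)."""

    def value_fn(y: jnp.ndarray) -> jnp.ndarray:
        x = grad_fn(y)
        if detach:
            # Treat x as a constant under differentiation (envelope theorem).
            x = jax.lax.stop_gradient(x)
        return -other_value_fn(x) + jnp.dot(y, x)

    return value_fn


g = make_potential_value_fn(grad_g, other_potential_value_fn)
g_no_detach = make_potential_value_fn(
    grad_g, other_potential_value_fn, detach=False
)
y = jnp.array([1.0, 2.0])

print(g(y))                      # approx. 0.7425
print(jax.grad(g)(y))            # equals grad_g(y) = [0.45, 0.45]
print(jax.grad(g_no_detach)(y))  # approx. [0.495, 0.495]
```

With the detached gradient, jax.grad(g)(y) reproduces the module output exactly even though that output is inexact. Without detaching, an extra term \(J^T (y - \nabla f(x))\), with J the Jacobian of the module output, leaks into the gradient; it vanishes only when \(\nabla f(x) = y\), i.e., at the exact conjugate gradient.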