ott.neural.models.MLP.potential_gradient_fn


MLP.potential_gradient_fn(params)

Return a function that, for a given input, outputs either a vector directly or the gradient of the potential.

Parameters:

params (FrozenDict[str, Array]) – Parameters of the module.

Return type:

Callable[[Array], Array]

Returns:

A function that, evaluated at an input array, returns the potential’s gradient at that input.
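To illustrate what the returned callable computes when the module models a scalar potential, here is a minimal sketch in plain JAX rather than OTT/Flax. The names `init_params`, `potential_value`, and this `potential_gradient_fn` are hypothetical stand-ins chosen for illustration. They are not OTT's implementation, and the real method takes Flax `FrozenDict` parameters rather than a plain dict. The idea is to close over the parameters and differentiate the scalar potential with respect to the input using `jax.grad`.

```python
from typing import Callable, Dict

import jax
import jax.numpy as jnp

# Hypothetical parameter container (OTT uses a Flax FrozenDict instead).
Params = Dict[str, jnp.ndarray]


def init_params(key: jax.Array, dim: int, hidden: int) -> Params:
    """Hypothetical initializer for a one-hidden-layer scalar potential."""
    k1, k2 = jax.random.split(key)
    return {
        "w1": jax.random.normal(k1, (dim, hidden)) / jnp.sqrt(dim),
        "b1": jnp.zeros(hidden),
        "w2": jax.random.normal(k2, (hidden,)) / jnp.sqrt(hidden),
    }


def potential_value(params: Params, x: jnp.ndarray) -> jnp.ndarray:
    """Scalar potential f(x) for a single input x of shape (dim,)."""
    h = jax.nn.softplus(x @ params["w1"] + params["b1"])  # shape (hidden,)
    return h @ params["w2"]  # scalar


def potential_gradient_fn(params: Params) -> Callable[[jnp.ndarray], jnp.ndarray]:
    """Close over params and return the map x -> grad_x f(x)."""
    return jax.grad(lambda x: potential_value(params, x))


params = init_params(jax.random.PRNGKey(0), dim=3, hidden=8)
grad_fn = potential_gradient_fn(params)
x = jnp.array([0.5, -1.0, 2.0])
g = grad_fn(x)  # gradient at a single point, shape (3,)
batch_g = jax.vmap(grad_fn)(jnp.ones((4, 3)))  # batched gradients, shape (4, 3)
```

Because the returned function acts on a single input, batching is left to `jax.vmap`, as in the last line above.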