ott.solvers.nn.models.MLP.potential_gradient_fn

MLP.potential_gradient_fn(params)

Return a function that evaluates the gradient of the potential for the given parameters.

Parameters:

params (FrozenDict[str, Array]) – parameters of the module

Return type:

Callable[[Array], Array]

Returns:

A function that takes an input array and returns the potential’s gradient at that input.
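
Example:

The following is a minimal sketch of the general idea, not the library's implementation. It assumes the potential is a scalar-valued function of a single point, and builds a gradient function by closing over the parameters, differentiating with `jax.grad`, and batching with `jax.vmap`. The names `toy_potential` and `make_gradient_fn`, and the quadratic form of the potential, are hypothetical and chosen only so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp


def toy_potential(params, x):
    # Hypothetical scalar potential f(x) = 0.5 * x^T A x + b^T x.
    # This stands in for a parameterized model; it is not OTT's MLP.
    return 0.5 * x @ params["A"] @ x + params["b"] @ x


def make_gradient_fn(params):
    # Close over params so the returned callable depends only on x,
    # matching the Callable[[Array], Array] return type documented above.
    grad_single = jax.grad(lambda x: toy_potential(params, x))
    # vmap applies the per-point gradient to each row of a batch.
    return jax.vmap(grad_single)


params = {
    "A": jnp.array([[2.0, 0.0], [0.0, 4.0]]),  # symmetric, so grad is A x + b
    "b": jnp.array([1.0, -1.0]),
}
grad_fn = make_gradient_fn(params)

xs = jnp.array([[1.0, 1.0], [0.0, 2.0]])  # batch of two 2-D points
grads = grad_fn(xs)  # shape (2, 2), one gradient per input point
```

For this symmetric `A`, the gradient at each point is `A x + b`, so the two rows of `grads` can be verified by hand. Whether the actual method expects batched or single inputs is not stated in the text above, so check the library source before relying on either convention.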