ott.neural.solvers.neuraldual.BaseW2NeuralDual.potential_gradient_fn


BaseW2NeuralDual.potential_gradient_fn(params)

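Return a function that evaluates either the gradient of the potential or a vector output directly, depending on whether the module parameterizes a scalar potential or a vector field.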

Parameters:

params (FrozenDict[str, Array]) – parameters of the module

Return type:

Callable[[Array], Array]

Returns:

A function that can be evaluated on input points to obtain the potential’s gradient.
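The dispatch described above can be illustrated with plain JAX. This is a minimal sketch, not the library's implementation: `potential_apply`, `vector_apply`, and `make_gradient_fn` are hypothetical stand-ins for a module's apply function and for `potential_gradient_fn`, and the params are a simple dict rather than a `FrozenDict`. A scalar potential is differentiated with `jax.grad` and vectorized over a batch with `jax.vmap`; a module that already outputs a vector is applied as-is.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a module computing a scalar potential per point:
# f(x) = 0.5 * scale * ||x||^2, whose gradient is scale * x.
def potential_apply(params, x):
    return 0.5 * params["scale"] * jnp.sum(x ** 2)


# Hypothetical stand-in for a module that outputs a vector per point directly.
def vector_apply(params, x):
    return params["scale"] * x


def make_gradient_fn(apply_fn, params, is_potential):
    """Hypothetical sketch: return a batched map from points to gradients/vectors."""
    if is_potential:
        # Differentiate the scalar potential w.r.t. one point, then batch over rows.
        return jax.vmap(jax.grad(lambda x: apply_fn(params, x)))
    # The module already returns a vector per point; just batch it over rows.
    return jax.vmap(lambda x: apply_fn(params, x))


params = {"scale": 2.0}
xs = jnp.array([[1.0, 0.0], [0.5, -1.0]])

grad_fn = make_gradient_fn(potential_apply, params, is_potential=True)
vec_fn = make_gradient_fn(vector_apply, params, is_potential=False)

# Both paths return 2 * xs for this toy example.
print(grad_fn(xs))
print(vec_fn(xs))
```

The toy potential is chosen so its gradient is known in closed form, which makes it easy to check that differentiating the potential and outputting the vector directly agree.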