ott.neural.networks.potentials.PotentialMLP.potential_gradient_fn


PotentialMLP.potential_gradient_fn()

Return a callable giving the gradient of the potential.

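For potential models, which output a scalar value per point, this returns vmap(grad(self)): the gradient of that scalar with respect to the input, vectorized over the batch axis. For gradient models, which already output the gradient (vector field) directly, this returns self unchanged.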
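The dispatch described above can be illustrated with plain JAX. This is a minimal sketch, not the ott implementation: the helper `potential_gradient_fn(potential, is_potential)` and the toy `quadratic_potential` below are hypothetical stand-ins chosen so the result is easy to check (the gradient of 0.5 * ||x||^2 is x itself).

```python
import jax
import jax.numpy as jnp


def quadratic_potential(x: jax.Array) -> jax.Array:
    """Toy scalar potential f(x) = 0.5 * ||x||^2 for a single point x of shape (d,)."""
    return 0.5 * jnp.sum(x ** 2)


def potential_gradient_fn(potential, is_potential: bool):
    """Hypothetical stand-in mirroring the documented behavior.

    If `potential` maps one point to a scalar, differentiate it and
    vectorize over the batch axis; otherwise assume it already returns
    the gradient (a vector field) and hand it back unchanged.
    """
    if is_potential:
        # jax.grad: R^d -> R^d for one point; jax.vmap lifts it to a batch (n, d).
        return jax.vmap(jax.grad(potential))
    return potential


x = jnp.array([[1.0, 2.0], [-3.0, 0.5]])  # batch of two 2-D points
grad_fn = potential_gradient_fn(quadratic_potential, is_potential=True)
grads = grad_fn(x)  # equals x, since grad(0.5 * ||x||^2) = x

# A "gradient model" outputs the vector field directly, so it is returned as-is.
vector_field = lambda y: 2.0 * y
field_fn = potential_gradient_fn(vector_field, is_potential=False)
```

Differentiating a single-point scalar function and then applying `jax.vmap` keeps the potential's signature simple (one point in, one scalar out) while still producing a batched `Callable[[Array], Array]`, matching the return type listed on this page.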

Return type:

Callable[[Array], Array]