ott.solvers.linear.implicit_differentiation.ImplicitDiff.gradient


ImplicitDiff.gradient(prob, f, g, lse_mode, gr)

Apply a vector-Jacobian product (VJP) to recover the gradient in reverse-mode differentiation.

Parameters:
prob, f, g, lse_mode, gr

Return type:
LinearProblem
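The entry above only says that a VJP is used to recover the gradient. The following is a minimal sketch of the general implicit-differentiation VJP idea in plain JAX. It does not use or reproduce OTT's implementation. The toy optimality condition `optimality`, the matrix `A`, and the helpers `solve` and `implicit_vjp` are hypothetical illustrations. The sketch assumes the solver output x* satisfies F(x*, theta) = 0. By the implicit function theorem, cotangent^T dx*/dtheta = -u^T dF/dtheta, where u solves (dF/dx)^T u = cotangent. The result is checked against ordinary reverse-mode autodiff through a direct solve.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy problem: F(x, theta) = A @ x - theta, so x*(theta) = A^{-1} theta.
A = jnp.array([[3.0, 1.0], [1.0, 2.0]])


def optimality(x, theta):
    # Optimality condition whose root is the solver output.
    return A @ x - theta


def solve(theta):
    # Stand-in for an iterative forward solver (illustration only).
    return jnp.linalg.solve(A, theta)


def implicit_vjp(x_star, theta, cotangent):
    """Return cotangent^T dx*/dtheta without differentiating through the solver."""
    # Dense Jacobian of F w.r.t. x at the solution (fine for a toy example).
    jac_x = jax.jacobian(optimality, argnums=0)(x_star, theta)
    # Solve the transposed linear system (dF/dx)^T u = cotangent.
    u = jnp.linalg.solve(jac_x.T, cotangent)
    # Pull -u back through F w.r.t. theta with a VJP.
    _, vjp_theta = jax.vjp(lambda t: optimality(x_star, t), theta)
    (grad_theta,) = vjp_theta(-u)
    return grad_theta


theta = jnp.array([1.0, -2.0])
x_star = solve(theta)
cotangent = jnp.array([0.5, 1.0])

# Gradient via implicit differentiation.
g_implicit = implicit_vjp(x_star, theta, cotangent)

# Reference: ordinary reverse-mode autodiff through the direct solve.
_, vjp_direct = jax.vjp(solve, theta)
(g_direct,) = vjp_direct(cotangent)
```

The design point is that only the optimality condition is differentiated, at the solution. The forward solver's iterations never need to be unrolled, so memory does not grow with the number of solver steps.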