class ott.core.quad_problems.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss=None, tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True)

Definition of the quadratic regularized OT problem.

The quadratic loss of a single OT matrix is assumed to have the form given in Eq. 4 from

http://proceedings.mlr.press/v48/peyre16.pdf

The two geometries below parameterize the matrices $C$ and $\bar{C}$ in that equation. The function $L$ (of two real arguments) in that equation is assumed to match the form given in Eq. 5, with our notations:
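For reference, Eq. 5 of the paper linked above decomposes the loss into two univariate terms and one product term. The symbols $f_1, f_2, h_1, h_2$ below are the paper's. Reading the `linear_loss` attribute as the pair $(f_1, f_2)$ and `quad_loss` as the pair $(h_1, h_2)$ is an assumption based on the attribute names, not something this page states.

$$
L(x, y) = f_1(x) + f_2(y) - h_1(x)\, h_2(y)
$$

For example, the squared loss $L(x, y) = (x - y)^2$ fits this form with $f_1(x) = x^2$, $f_2(y) = y^2$, $h_1(x) = x$ and $h_2(y) = 2y$.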

| Method | Description |
|---|---|
| `cost_unbalanced_correction`(transport_matrix, ...) | Calculate cost term from the quadratic divergence when unbalanced. |
| `init_linearization`([epsilon]) | Initialise a linear problem locally around a naive initializer ab'. |
| `init_lr_linearization`([rank]) | Linearizes a Quad problem with a predefined initializer. |
| | Initialise the transport matrix. |
| | Initialise the transport mass. |
| `marginal_dependent_cost`(marginal_1, marginal_2) | Initialise cost term that depends on the marginals of the transport. |
| `update_linearization`(transport[, epsilon, ...]) | Update linearization of GW problem by updating cost matrix. |
| `update_lr_geom`(lr_sink) | Recompute (possibly LRC) linearization using LR Sinkhorn output. |
| `update_lr_linearization`(lr_sink) | Update a Quad problem linearization using a LR Sinkhorn. |

| Attribute | Return type |
|---|---|
| `a` | `ndarray` |
| `b` | `ndarray` |
| `is_all_geoms_lr` | `bool` |
| `is_balanced` | `bool` |
| `is_fused` | `bool` |
| `linear_loss` | |
| `quad_loss` | |
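To make the idea of a linearization concrete, here is a minimal pure-Python sketch, not the library's implementation. It assumes the squared loss and the naive initializer ab' mentioned in the methods above, and it compares a brute-force evaluation of the linearized cost matrix with a factored one that only needs the marginals of the transport plus one matrix product. All function names (`naive_linearization`, `factored_linearization`) and the toy matrices are hypothetical; that `marginal_dependent_cost` and `update_linearization` correspond to the two parts of the factored form is an assumption.

```python
# Squared loss L(x, y) = (x - y)^2 written as f1(x) + f2(y) - h1(x) * h2(y).
def f1(x): return x * x
def f2(y): return y * y
def h1(x): return x
def h2(y): return 2.0 * y

def loss(x, y):
    return (x - y) ** 2

def naive_linearization(C, Cbar, T):
    """Brute force: entry (i, j) is sum_{k,l} L(C[i][k], Cbar[j][l]) * T[k][l]."""
    n, m = len(C), len(Cbar)
    return [[sum(loss(C[i][k], Cbar[j][l]) * T[k][l]
                 for k in range(n) for l in range(m))
             for j in range(m)] for i in range(n)]

def factored_linearization(C, Cbar, T):
    """Same matrix, computed from the marginals of T plus one cross term."""
    n, m = len(C), len(Cbar)
    p = [sum(T[k]) for k in range(n)]                       # row marginal of T
    q = [sum(T[k][l] for k in range(n)) for l in range(m)]  # column marginal of T
    # Part that depends on T only through its marginals.
    row_term = [sum(f1(C[i][k]) * p[k] for k in range(n)) for i in range(n)]
    col_term = [sum(f2(Cbar[j][l]) * q[l] for l in range(m)) for j in range(m)]
    # Cross term h1(C) T h2(Cbar)^T, computed as two successive products.
    hCT = [[sum(h1(C[i][k]) * T[k][l] for k in range(n)) for l in range(m)]
           for i in range(n)]
    cross = [[sum(hCT[i][l] * h2(Cbar[j][l]) for l in range(m))
              for j in range(m)] for i in range(n)]
    return [[row_term[i] + col_term[j] - cross[i][j] for j in range(m)]
            for i in range(n)]

# Toy data (hypothetical): two small symmetric cost matrices and uniform weights.
C = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]
Cbar = [[0.0, 3.0], [3.0, 0.0]]
a = [1.0 / 3.0] * 3
b = [0.5, 0.5]
T = [[ai * bj for bj in b] for ai in a]  # naive initializer ab'

naive = naive_linearization(C, Cbar, T)
factored = factored_linearization(C, Cbar, T)
max_diff = max(abs(naive[i][j] - factored[i][j])
               for i in range(len(C)) for j in range(len(Cbar)))
# Quadratic objective evaluated at T: sum_{i,j} linearized_cost[i][j] * T[i][j].
gw_cost = sum(factored[i][j] * T[i][j]
              for i in range(len(C)) for j in range(len(Cbar)))
print(max_diff, gw_cost)
```

The brute-force version touches every (i, j, k, l) combination, while the factored version reduces to matrix-vector and matrix-matrix products, which is why a decomposable loss matters in practice.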