class ott.problems.quadratic.quadratic_problem.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01)

The quadratic loss of a single OT matrix (coupling) is assumed to have the form given in eq. 4 of the cited reference.
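A minimal statement of that form, assuming the cited eq. 4 is the standard Gromov-Wasserstein objective for a coupling $T$ between two spaces with cost matrices $C$ and $\bar{C}$ and a pointwise loss $L$:

$$
\begin{aligned}
\mathcal{E}(T) &= \sum_{i,j,k,l} L\big(C_{ik}, \bar{C}_{jl}\big)\, T_{ij}\, T_{kl} = \big\langle L(C, \bar{C}) \otimes T,\; T \big\rangle, \\
\big(L(C, \bar{C}) \otimes T\big)_{ij} &= \sum_{k,l} L\big(C_{ik}, \bar{C}_{jl}\big)\, T_{kl}.
\end{aligned}
$$

The second expression rewrites the four-index sum as an inner product with the tensor product $L(C, \bar{C}) \otimes T$; evaluating that tensor product cheaply is what the decomposition of $L$ is for.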

The two geometries geom_xx and geom_yy parameterize the matrices $C$ and $\bar{C}$ in that equation. The function $L$ (of two real values) in that equation is assumed to match the form given in eq. 5; in our notation:

$$L(x, y) = \mathrm{lin1}(x) + \mathrm{lin2}(y) - \mathrm{quad1}(x)\,\mathrm{quad2}(y)$$
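To illustrate the decomposition, the sketch below writes two common losses in this form and checks numerically that the four pieces recombine into $L$. The names `SQUARE`, `KL` and `decomposed_loss` are hypothetical. We assume the default `loss='sqeucl'` refers to the squared loss $L(x, y) = (x - y)^2$; the split OTT uses internally may differ, and any split with the same lin terms and the same product $\mathrm{quad1}(x)\,\mathrm{quad2}(y)$ is equivalent. The generalized Kullback-Leibler loss is shown only as a second, common example.

```python
import math

# Squared loss: (x - y)^2 = x^2 + y^2 - x * (2y).
# Assumption: one valid split; e.g. quad1(x) = 2x, quad2(y) = y works equally well.
SQUARE = {
    "lin1": lambda x: x * x,
    "lin2": lambda y: y * y,
    "quad1": lambda x: x,
    "quad2": lambda y: 2.0 * y,
}

# Generalized KL loss for x, y > 0: x log(x / y) - x + y
#   = (x log x - x) + y - x * log y.
KL = {
    "lin1": lambda x: x * math.log(x) - x,
    "lin2": lambda y: y,
    "quad1": lambda x: x,
    "quad2": lambda y: math.log(y),
}


def decomposed_loss(parts, x, y):
    """Evaluate L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)."""
    return parts["lin1"](x) + parts["lin2"](y) - parts["quad1"](x) * parts["quad2"](y)


def square_loss(x, y):
    return (x - y) ** 2


def kl_loss(x, y):
    return x * math.log(x / y) - x + y


if __name__ == "__main__":
    for x, y in [(0.5, 2.0), (1.0, 1.0), (3.0, 0.25)]:
        # abs_tol guards the (1.0, 1.0) case, where both sides are 0.
        assert math.isclose(decomposed_loss(SQUARE, x, y), square_loss(x, y), abs_tol=1e-12)
        assert math.isclose(decomposed_loss(KL, x, y), kl_loss(x, y), abs_tol=1e-12)
    print("both decompositions recombine into L")
```

The point of the split is that $L$ is never evaluated as a function of both arguments at once: each piece depends on only one of $x$ or $y$, so it can be applied elementwise to $C$ or $\bar{C}$ separately.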
Parameters

Methods

- cost_unbalanced_correction(transport_matrix, ...): Calculate the cost term from the quadratic divergence when the problem is unbalanced.
- …: Initialise the transport mass.
- marginal_dependent_cost(marginal_1, marginal_2): Initialise the cost term that depends on the marginals of the transport.
- to_low_rank([seed]): Convert the geometries to low-rank.
- update_linearization(transport[, epsilon, ...]): Update the linearization of the GW problem by updating the cost matrix.
- update_lr_geom(lr_sink): Recompute the (possibly low-rank cost, LRC) linearization using the low-rank (LR) Sinkhorn output.
- update_lr_linearization(lr_sink): Update the linearization of a quadratic problem using a low-rank Sinkhorn output.
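The methods above do not spell out how `update_linearization` or `marginal_dependent_cost` build their cost terms. A standard reason for decomposing $L$ is that the tensor product $\big(L(C, \bar{C}) \otimes T\big)_{ij} = \sum_{k,l} L(C_{ik}, \bar{C}_{jl})\, T_{kl}$ factorizes as $(\mathrm{lin1}(C)\,p)_i + (\mathrm{lin2}(\bar{C})\,q)_j - \big(\mathrm{quad1}(C)\, T\, \mathrm{quad2}(\bar{C})^\top\big)_{ij}$ with $p = T\mathbf{1}$ and $q = T^\top\mathbf{1}$. The first two terms depend on $T$ only through its marginals, which is presumably the role of `marginal_dependent_cost`. The stdlib-only sketch below (all helper names hypothetical; an illustration, not OTT's implementation) checks this identity against the naive four-index sum for the squared loss.

```python
import random


def apply(f, M):
    """Apply a scalar function elementwise to a nested-list matrix."""
    return [[f(v) for v in row] for row in M]


def matmul(A, B):
    """Dense product of nested-list matrices A (n x k) and B (k x m)."""
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]


def matvec(M, v):
    return [sum(a * b for a, b in zip(row, v)) for row in M]


def naive_tensor(L, C, Cbar, T):
    """(L(C, Cbar) x T)_ij = sum_{k,l} L(C_ik, Cbar_jl) T_kl, in O(n^2 m^2)."""
    n, m = len(C), len(Cbar)
    return [[sum(L(C[i][k], Cbar[j][l]) * T[k][l]
                 for k in range(n) for l in range(m))
             for j in range(m)] for i in range(n)]


def factorized_tensor(parts, C, Cbar, T):
    """Same quantity via L = lin1 + lin2 - quad1 * quad2, in O(n^2 m + n m^2)."""
    lin1, lin2, quad1, quad2 = parts
    p = [sum(row) for row in T]           # p = T 1 (row sums)
    q = [sum(col) for col in zip(*T)]     # q = T^T 1 (column sums)
    u = matvec(apply(lin1, C), p)         # marginal-dependent term, indexed by i
    v = matvec(apply(lin2, Cbar), q)      # marginal-dependent term, indexed by j
    quad2_t = [list(col) for col in zip(*apply(quad2, Cbar))]  # quad2(Cbar)^T
    Q = matmul(matmul(apply(quad1, C), T), quad2_t)
    return [[u[i] + v[j] - Q[i][j] for j in range(len(Cbar))]
            for i in range(len(C))]


if __name__ == "__main__":
    rng = random.Random(0)
    n, m = 4, 3
    C = [[rng.random() for _ in range(n)] for _ in range(n)]
    Cbar = [[rng.random() for _ in range(m)] for _ in range(m)]
    T = [[1.0 / (n * m)] * m for _ in range(n)]  # uniform coupling
    square = (lambda x: x * x, lambda y: y * y, lambda x: x, lambda y: 2.0 * y)
    fast = factorized_tensor(square, C, Cbar, T)
    slow = naive_tensor(lambda x, y: (x - y) ** 2, C, Cbar, T)
    assert all(abs(a - b) < 1e-9 for ra, rb in zip(fast, slow) for a, b in zip(ra, rb))
    print("factorized tensor product matches the naive sum")
```

In a real implementation the same identity would be written with array operations (e.g. `jax.numpy` matrix products), and the matrix products are also what low-rank geometries can accelerate further.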

Attributes

- a: First marginal.
- b: Second marginal.
- geom_xx: Geometry of the first space.
- geom_xy: Geometry of the joint space.
- geom_yy: Geometry of the second space.
- is_balanced: Whether the problem is balanced.
- is_fused: Whether the problem is fused.
- is_low_rank: Whether all geometries are low-rank.
- linear_loss: Linear part of the Gromov-Wasserstein loss.
- quad_loss: Quadratic part of the Gromov-Wasserstein loss.
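For `is_low_rank` and `to_low_rank`, one case where an exact low-rank form exists is the squared-Euclidean cost between point clouds in $\mathbb{R}^d$: since $\|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2\langle x_i, y_j\rangle$, the $n \times m$ cost matrix equals $AB^\top$ with only $d + 2$ columns. The plain-Python sketch below (hypothetical helper names) builds these factors and checks them; how OTT converts geometries without such a closed form, which the `ranks` and `tolerances` arguments relate to, is not covered here.

```python
import random


def sq_euclidean_cost(X, Y):
    """Full n x m matrix of squared Euclidean distances between rows of X and Y."""
    return [[sum((a - b) ** 2 for a, b in zip(x, y)) for y in Y] for x in X]


def sq_euclidean_factors(X, Y):
    """Exact factors A (n x (d+2)) and B (m x (d+2)) with A B^T = sq_euclidean_cost(X, Y).

    Row i of A is [|x_i|^2, 1, -2 x_i]; row j of B is [1, |y_j|^2, y_j], so
    A_i . B_j = |x_i|^2 + |y_j|^2 - 2 <x_i, y_j> = |x_i - y_j|^2.
    """
    A = [[sum(v * v for v in x), 1.0] + [-2.0 * v for v in x] for x in X]
    B = [[1.0, sum(v * v for v in y)] + list(y) for y in Y]
    return A, B


def times_transpose(A, B):
    """Compute A B^T for nested-list matrices with equal row lengths."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B] for ra in A]


if __name__ == "__main__":
    rng = random.Random(0)
    d = 3
    X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(5)]
    Y = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(4)]
    A, B = sq_euclidean_factors(X, Y)
    assert len(A[0]) == len(B[0]) == d + 2
    full = sq_euclidean_cost(X, Y)
    rebuilt = times_transpose(A, B)
    assert all(abs(u - v) < 1e-9 for ru, rv in zip(full, rebuilt) for u, v in zip(ru, rv))
    print("rank", d + 2, "factorization reproduces the", len(X), "x", len(Y), "cost")
```

The saving grows with size: for $n = m = 10^4$ points in $d = 3$ dimensions, the two factors hold $2 \cdot 10^4 \cdot 5 = 10^5$ numbers instead of the $10^8$ entries of the full cost matrix.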