class ott.problems.quadratic.quadratic_problem.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=None, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01)

Quadratic OT problem.

The quadratic loss of a single OT matrix is assumed to have the form given in [Peyré et al., 2016], eq. 4.

The two geometries below parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) in that equation is assumed to have the form given in eq. 5, which in our notation reads:

\[L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)\]
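As a worked check of this decomposition, the NumPy sketch below assumes that the 'sqeucl' and 'kl' options of loss follow the two losses of [Peyré et al., 2016]: \((x-y)^2 = x^2 + y^2 - x \cdot 2y\) and \(x\log(x/y) - x + y = (x\log x - x) + y - x\log y\). The LOSSES dictionary and decomposed_loss helper are hypothetical illustrations, not OTT's GWLoss API.

```python
import numpy as np

# Hypothetical (f1, f2, h1, h2) tuples; illustrative only, not OTT's GWLoss objects.
LOSSES = {
    # (x - y)^2 = x^2 + y^2 - x * (2 y)
    "sqeucl": (
        lambda x: x**2,
        lambda y: y**2,
        lambda x: x,
        lambda y: 2.0 * y,
    ),
    # x log(x / y) - x + y = (x log x - x) + y - x * log(y), valid for x, y > 0
    "kl": (
        lambda x: x * np.log(x) - x,
        lambda y: y,
        lambda x: x,
        lambda y: np.log(y),
    ),
}


def decomposed_loss(name, x, y):
    """Evaluate L(x, y) = f1(x) + f2(y) - h1(x) * h2(y) elementwise."""
    f1, f2, h1, h2 = LOSSES[name]
    return f1(x) + f2(y) - h1(x) * h2(y)


rng = np.random.default_rng(0)
x = rng.uniform(0.1, 2.0, size=100)
y = rng.uniform(0.1, 2.0, size=100)

# The decomposed form agrees with the direct formula of each loss.
assert np.allclose(decomposed_loss("sqeucl", x, y), (x - y) ** 2)
assert np.allclose(decomposed_loss("kl", x, y), x * np.log(x / y) - x + y)
```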

Parameters:

  • geom_xx (Geometry) – Ground geometry of the first space.

  • geom_yy (Geometry) – Ground geometry of the second space.

  • geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for fused Gromov-Wasserstein [Titouan et al., 2019]. If None, the problem reduces to a plain Gromov-Wasserstein problem [Peyré et al., 2016].

  • fused_penalty (float) – Multiplier of the linear term in fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem.

  • scale_cost (Union[float, str, None]) – How to rescale the cost matrices. If a str, use one of the scaling options available in Geometry or PointCloud. If None, keep the original scaling.

  • a (Optional[Array]) – The first marginal. If None, it will be uniform.

  • b (Optional[Array]) – The second marginal. If None, it will be uniform.

  • loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – Gromov-Wasserstein loss function, see GWLoss for more information.

  • tau_a (float) – If \(< 1.0\), controls how unbalanced the problem is on the first marginal.

  • tau_b (float) – If \(< 1.0\), controls how unbalanced the problem is on the second marginal.

  • gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the inner Sinkhorn loop.

  • ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Used when geometries are not PointCloud with ‘sqeucl’ cost function. If -1, the geometries will not be converted to low-rank. If tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If int, rank is shared across all geometries.

  • tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when geometries are not PointCloud with ‘sqeucl’ cost. If float, it is shared across all geometries.
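To make the roles of geom_xx, geom_yy, geom_xy, fused_penalty, a and b concrete, the NumPy sketch below evaluates "purely quadratic + fused_penalty * linear" for a fixed coupling by brute force. It assumes the squared loss ('sqeucl') and squared-Euclidean ground costs, models geom_xy as a cost between hypothetical features z, and ignores scale_cost, unbalancedness and low-rank conversion; the helper names are hypothetical and nothing here calls OTT.

```python
import numpy as np


def sqeucl_cost(x, y):
    """Pairwise squared Euclidean distances between the rows of x and y."""
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)


def fused_gw_objective(C, C_bar, M, T, fused_penalty=1.0):
    """Naive O(n^2 m^2) evaluation of quadratic + fused_penalty * linear term.

    Quadratic term (Peyré et al., 2016, eq. 4) with L(x, y) = (x - y)^2:
        sum_{i,j,k,l} L(C[i, k], C_bar[j, l]) * T[i, j] * T[k, l]
    Linear term:
        sum_{i,j} M[i, j] * T[i, j]
    """
    # L4[i, k, j, l] = (C[i, k] - C_bar[j, l])^2
    L4 = (C[:, :, None, None] - C_bar[None, None, :, :]) ** 2
    quad = np.einsum("ikjl,ij,kl->", L4, T, T)
    lin = np.sum(M * T)
    return quad + fused_penalty * lin


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))  # points in the first space
y = rng.normal(size=(4, 3))  # points in the second space
z = rng.normal(size=(4, 2))  # hypothetical features of y, comparable with x

C = sqeucl_cost(x, x)      # stands in for the cost matrix of geom_xx
C_bar = sqeucl_cost(y, y)  # stands in for the cost matrix of geom_yy
M = sqeucl_cost(x, z)      # stands in for the cost matrix of geom_xy

# a=None and b=None mean uniform marginals.
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)
T = np.outer(a, b)  # independent coupling: satisfies both marginals

plain = fused_gw_objective(C, C_bar, M, T, fused_penalty=0.0)  # plain GW value
fused = fused_gw_objective(C, C_bar, M, T, fused_penalty=2.0)  # fused GW value
```

With fused_penalty=0.0 the linear term vanishes, which matches the plain Gromov-Wasserstein case obtained when geom_xy is None.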

Methods:

  • cost_unbalanced_correction(transport_matrix, ...) – Calculate the cost term from the quadratic divergence when the problem is unbalanced.

  • Initialize the transport mass.

  • marginal_dependent_cost(marginal_1, marginal_2) – Initialize the cost term that depends on the marginals of the transport.

  • Convert geometries to low-rank.

  • update_linearization(transport[, epsilon, ...]) – Update the linearization of the GW problem by updating the cost matrix.

  • update_lr_geom(lr_sink[, relative_epsilon]) – Recompute the (possibly low-rank cost, LRC) linearization using the low-rank Sinkhorn output.

  • update_lr_linearization(lr_sink, *[, ...]) – Update the linearization of a quadratic problem using a low-rank Sinkhorn output.
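The decomposition of \(L\) is what makes a linearized cost matrix cheap to form: for a coupling \(T\) with marginals \(p = T\mathbf{1}\) and \(q = T^\top\mathbf{1}\), \(\sum_{k,l} L(C_{ik}, \bar{C}_{jl}) T_{kl} = (f_1(C)p)_i + (f_2(\bar{C})q)_j - (h_1(C)\,T\,h_2(\bar{C})^\top)_{ij}\). The first two terms depend only on the marginals, which appears to be the role of marginal_dependent_cost. The NumPy sketch below checks this identity for the squared loss against a brute-force sum. The helper names are hypothetical, and it does not reproduce OTT's exact scaling, epsilon handling or low-rank path.

```python
import numpy as np


def tensor_product_naive(C, C_bar, T):
    """(L ⊗ T)[i, j] = sum_{k,l} (C[i, k] - C_bar[j, l])^2 * T[k, l], O(n^2 m^2)."""
    L4 = (C[:, :, None, None] - C_bar[None, None, :, :]) ** 2
    return np.einsum("ikjl,kl->ij", L4, T)


def tensor_product_fast(C, C_bar, T):
    """Same quantity via L(x, y) = f1(x) + f2(y) - h1(x) h2(y) (squared loss)."""
    p = T.sum(axis=1)  # first marginal of T
    q = T.sum(axis=0)  # second marginal of T
    # Marginal-dependent part: f1(C) p 1^T + 1 (f2(C_bar) q)^T, with f1 = f2 = x^2
    marginal_part = (C**2 @ p)[:, None] + (C_bar**2 @ q)[None, :]
    # Coupling-dependent part: h1(C) T h2(C_bar)^T, with h1(x) = x, h2(y) = 2y
    return marginal_part - C @ T @ (2.0 * C_bar).T


rng = np.random.default_rng(0)
n, m = 6, 5
x, y = rng.normal(size=(n, 2)), rng.normal(size=(m, 3))
C = ((x[:, None] - x[None]) ** 2).sum(-1)
C_bar = ((y[:, None] - y[None]) ** 2).sum(-1)
T = rng.uniform(size=(n, m))
T /= T.sum()  # any nonnegative matrix works; normalized for readability

lin_cost = tensor_product_fast(C, C_bar, T)  # linearized cost matrix at T
gw_value = np.sum(lin_cost * T)              # quadratic objective <L ⊗ T, T>
```

The fast form needs only matrix products, costing \(O(n^2 m + n m^2)\) instead of \(O(n^2 m^2)\).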

Attributes:

  • First marginal.

  • Second marginal.

  • Geometry of the first space.

  • Geometry of the joint space.

  • Geometry of the second space.

  • Whether the problem is balanced.

  • Whether the problem is fused.

  • Whether all geometries are low-rank.

  • Linear part of the Gromov-Wasserstein loss.

  • Quadratic part of the Gromov-Wasserstein loss.