ott.problems.quadratic.quadratic_problem.QuadraticProblem

class ott.problems.quadratic.quadratic_problem.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01)

Quadratic OT problem.

The quadratic loss of a single OT matrix is assumed to have the form given in [Peyré et al., 2016], eq. 4.

The geometries geom_xx and geom_yy parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) of two real values in that equation is assumed to have the form given in eq. 5, which in our notation reads:

\[L(x, y) = \mathrm{lin1}(x) + \mathrm{lin2}(y) - \mathrm{quad1}(x)\,\mathrm{quad2}(y)\]
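As a sanity check of this decomposition, the plain-Python sketch below writes two common losses in the lin1/lin2/quad1/quad2 form, following the splits described in [Peyré et al., 2016]: the squared Euclidean loss \((x - y)^2 = x^2 + y^2 - x \cdot 2y\) and the KL loss \(x \log(x/y) - x + y\). The names SQEUCL, KL and decomposed_loss are hypothetical, and placing the factor 2 in quad2 is one valid choice among several; this is an illustration, not the library's own loss objects.

```python
import math

# Hypothetical representation: a loss is a 4-tuple (lin1, lin2, quad1, quad2) with
# L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y).
SQEUCL = (
    lambda x: x ** 2,   # lin1
    lambda y: y ** 2,   # lin2
    lambda x: x,        # quad1
    lambda y: 2.0 * y,  # quad2 carries the factor 2 of the cross term
)

KL = (
    lambda x: x * math.log(x) - x,  # lin1 (this sketch needs x > 0)
    lambda y: y,                    # lin2
    lambda x: x,                    # quad1
    lambda y: math.log(y),          # quad2 (needs y > 0)
)


def decomposed_loss(loss, x, y):
    """Evaluate L(x, y) from its linear and quadratic parts."""
    lin1, lin2, quad1, quad2 = loss
    return lin1(x) + lin2(y) - quad1(x) * quad2(y)


# Compare the decomposed form with the direct formulas on a few positive inputs.
for x, y in [(0.5, 2.0), (1.5, 0.25), (3.0, 3.0)]:
    assert math.isclose(decomposed_loss(SQEUCL, x, y), (x - y) ** 2, abs_tol=1e-12)
    assert math.isclose(decomposed_loss(KL, x, y), x * math.log(x / y) - x + y, abs_tol=1e-12)
```

Both losses share quad1(x) = x, so only lin1, lin2 and quad2 change between them.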
Parameters
  • geom_xx (Geometry) – Ground geometry of the first space.

  • geom_yy (Geometry) – Ground geometry of the second space.

  • geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.

  • fused_penalty (float) – Multiplier of the linear term in Fused Gromov-Wasserstein, i.e., the objective is the purely quadratic term plus fused_penalty times the linear term. Ignored if geom_xy is not specified.

  • scale_cost (Union[bool, float, str, None]) –

    Option to rescale the cost matrices:

    • if True, use the default for each geometry.

    • if False, keep the original scaling in geometries.

    • if str, use a specific method available in Geometry or PointCloud.

    • if None, do not scale the cost matrices.

  • a (Optional[Array]) – Array of probability weights of the samples from geom_xx. If None, uniform weights are used.

  • b (Optional[Array]) – Array of probability weights of the samples from geom_yy. If None, uniform weights are used.

  • loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – Either 'sqeucl', 'kl', or a GWLoss. The loss consists of four functions: the linear part (lin1, lin2) and the quadratic part (quad1, quad2). By default ('sqeucl'), these are the four functions representing the squared Euclidean loss, a structure that is exploited in subsequent computations. The KL loss ('kl') is handled in an equally optimized way.

  • tau_a (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the first marginal.

  • tau_b (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the second marginal.

  • gw_unbalanced_correction (bool) – Whether to use the unbalanced version of [Sejourne et al., 2021]. If False, tau_a and tau_b only affect the inner Sinkhorn loop.

  • ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Used when the geometries are not PointCloud objects with the 'sqeucl' cost function. If -1, the geometries are not converted to low-rank. If tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If int, the rank is shared across all geometries.

  • tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when the geometries are not PointCloud objects with the 'sqeucl' cost. If float, it is shared across all geometries.
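To make the role of fused_penalty and of the default uniform marginals concrete, here is a naive pure-Python evaluation of the (fused) Gromov-Wasserstein objective for a fixed coupling \(T\), assuming the squared Euclidean loss. The quadratic term is the four-index sum \(\sum_{i,j,k,l} L(C_{ik}, \bar{C}_{jl}) T_{ij} T_{kl}\) from [Peyré et al., 2016]; when a matrix \(M\) standing in for the cost of geom_xy is supplied, fused_penalty times \(\langle M, T \rangle\) is added. All function names are hypothetical, and the \(O(n^2 m^2)\) loop only illustrates what is being optimized, not how the library computes it.

```python
# Naive evaluation of the (fused) Gromov-Wasserstein objective for a fixed coupling T.
# C is n x n (stands in for geom_xx), Cbar is m x m (geom_yy), T is n x m, and the
# optional M is n x m (stands in for the cost of geom_xy). All names are hypothetical.


def sqeucl(x, y):
    return (x - y) ** 2


def uniform(n):
    """Uniform weights, mirroring the default used when a or b is None."""
    return [1.0 / n] * n


def outer(a, b):
    """Independent coupling a b^T; its marginals are exactly a and b."""
    return [[ai * bj for bj in b] for ai in a]


def quadratic_term(C, Cbar, T, loss=sqeucl):
    """sum_{i,j,k,l} L(C[i][k], Cbar[j][l]) * T[i][j] * T[k][l], in O(n^2 m^2)."""
    n, m = len(C), len(Cbar)
    return sum(
        loss(C[i][k], Cbar[j][l]) * T[i][j] * T[k][l]
        for i in range(n) for j in range(m)
        for k in range(n) for l in range(m)
    )


def objective(C, Cbar, T, M=None, fused_penalty=1.0, loss=sqeucl):
    """Purely quadratic term, plus fused_penalty * <M, T> when M is given."""
    value = quadratic_term(C, Cbar, T, loss)
    if M is None:  # no joint geometry: plain Gromov-Wasserstein
        return value
    linear = sum(M[i][j] * T[i][j] for i in range(len(T)) for j in range(len(T[0])))
    return value + fused_penalty * linear


C = [[0.0, 1.0], [1.0, 0.0]]  # two points at distance 1
M = [[0.0, 1.0], [1.0, 0.0]]  # cross-space cost between the two spaces
indep = outer(uniform(2), uniform(2))
swap = [[0.0, 0.5], [0.5, 0.0]]
print(objective(C, C, indep))                         # 0.5
print(objective(C, C, swap, M=M, fused_penalty=2.0))  # 2.0
```

The swap coupling is an isometry between the two spaces, so its quadratic term is zero and the fused value comes entirely from fused_penalty times the linear term.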

Methods

cost_unbalanced_correction(transport_matrix, ...)

Calculate the cost term arising from the quadratic divergence when the problem is unbalanced.

init_transport_mass()

Initialise the transport mass.

marginal_dependent_cost(marginal_1, marginal_2)

Initialise the cost term that depends on the marginals of the transport.

to_low_rank([seed])

Convert geometries to low-rank.

update_linearization(transport[, epsilon, ...])

Update the linearization of the GW problem by updating the cost matrix.

update_lr_geom(lr_sink)

Recompute the (possibly LRC) linearization using the LR Sinkhorn output.

update_lr_linearization(lr_sink)

Update the linearization of a quadratic problem using an LR Sinkhorn output.
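Methods such as marginal_dependent_cost and update_linearization appear to revolve around linearizing the quadratic objective around a current coupling \(T\). A minimal sketch of the standard factorization from [Peyré et al., 2016], assuming the decomposable loss above: the matrix \(G_{ij} = \sum_{k,l} L(C_{ik}, \bar{C}_{jl}) T_{kl}\) splits into a term that depends on \(T\) only through its marginals \(p = T\mathbf{1}\) and \(q = T^\top\mathbf{1}\), plus the term \(-\mathrm{quad1}(C)\, T\, \mathrm{quad2}(\bar{C})^\top\), which brings the cost down from \(O(n^2 m^2)\) to \(O(n^2 m + n m^2)\). The helper names are hypothetical and the code is plain Python; the library's implementation may add scaling factors, unbalanced corrections or a fused term that are not shown here.

```python
import math


# Hypothetical helpers on lists of lists, to keep the sketch dependency-free.
def matmul(A, B):
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def transpose(A):
    return [list(row) for row in zip(*A)]


def elementwise(fn, A):
    return [[fn(v) for v in row] for row in A]


def linearized_cost(C, Cbar, T, lin1, lin2, quad1, quad2):
    """n x m matrix G with G[i][j] = sum_{k,l} L(C[i][k], Cbar[j][l]) * T[k][l]."""
    n, m = len(C), len(Cbar)
    p = [sum(T[i]) for i in range(n)]                       # p = T 1_m
    q = [sum(T[i][j] for i in range(n)) for j in range(m)]  # q = T^T 1_n
    # Marginal-dependent part: only needs p and q, not T itself.
    u = [sum(lin1(C[i][k]) * p[k] for k in range(n)) for i in range(n)]
    v = [sum(lin2(Cbar[j][l]) * q[l] for l in range(m)) for j in range(m)]
    # Coupling-dependent part: quad1(C) @ T @ quad2(Cbar)^T.
    cross = matmul(matmul(elementwise(quad1, C), T), transpose(elementwise(quad2, Cbar)))
    return [[u[i] + v[j] - cross[i][j] for j in range(m)] for i in range(n)]


def naive_energy(C, Cbar, T, loss):
    """Four-index sum, used only to check the factorized version."""
    n, m = len(C), len(Cbar)
    return sum(loss(C[i][k], Cbar[j][l]) * T[i][j] * T[k][l]
               for i in range(n) for j in range(m) for k in range(n) for l in range(m))


# Squared Euclidean loss in (lin1, lin2, quad1, quad2) form.
sq = (lambda x: x ** 2, lambda y: y ** 2, lambda x: x, lambda y: 2.0 * y)
C = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.5], [2.0, 1.5, 0.0]]
Cbar = [[0.0, 2.0], [2.0, 0.0]]
T = [[0.1, 0.2], [0.3, 0.1], [0.05, 0.25]]

G = linearized_cost(C, Cbar, T, *sq)
fast = sum(G[i][j] * T[i][j] for i in range(3) for j in range(2))  # <G, T>
slow = naive_energy(C, Cbar, T, lambda x, y: (x - y) ** 2)
assert math.isclose(fast, slow, rel_tol=1e-9)
```

Since \(\langle G, T \rangle\) equals the quadratic objective, \(G\) plays the role of the cost matrix of a linear OT problem at the current \(T\), which is the sense in which a linearization "updates the cost matrix".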

Attributes

a

First marginal.

b

Second marginal.

geom_xx

Geometry of the first space.

geom_xy

Geometry of the joint space.

geom_yy

Geometry of the second space.

is_balanced

Whether the problem is balanced.

is_fused

Whether the problem is fused.

is_low_rank

Whether all geometries are low-rank.

linear_loss

Linear part of the Gromov-Wasserstein loss.

quad_loss

Quadratic part of the Gromov-Wasserstein loss.