ott.core.quad_problems.QuadraticProblem

class ott.core.quad_problems.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01)

Quadratic regularized OT problem.

The quadratic loss of a single OT matrix is assumed to have the form given in [Peyré et al., 2016], eq. 4.

The two geometries below parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) is assumed to decompose as in eq. 5 of the same reference which, in our notation, reads:

\[L(x, y) = \mathrm{lin1}(x) + \mathrm{lin2}(y) - \mathrm{quad1}(x) \cdot \mathrm{quad2}(y)\]
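As a sanity check of this decomposition, the sketch below spells out one possible choice of the four functions for the squared Euclidean and KL losses, and verifies numerically that they recombine into the usual closed forms. The names SQEUCL, KL and decomposed_loss are illustrative and not part of the library; the way constants are split between quad1 and quad2 is an assumption, since only their product matters.

```python
import math

# Squared Euclidean loss: (x - y)^2 = x^2 + y^2 - x * (2 * y).
# Assumed split of the factor 2 (it could equally sit in quad1).
SQEUCL = {
    "lin1": lambda x: x * x,
    "lin2": lambda y: y * y,
    "quad1": lambda x: x,
    "quad2": lambda y: 2.0 * y,
}

# KL loss: x * log(x / y) - x + y = (x * log(x) - x) + y - x * log(y).
KL = {
    "lin1": lambda x: x * math.log(x) - x,
    "lin2": lambda y: y,
    "quad1": lambda x: x,
    "quad2": lambda y: math.log(y),
}


def decomposed_loss(parts, x, y):
    """Evaluate L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)."""
    return (
        parts["lin1"](x)
        + parts["lin2"](y)
        - parts["quad1"](x) * parts["quad2"](y)
    )


def sqeucl_direct(x, y):
    """Closed form of the squared Euclidean loss."""
    return (x - y) ** 2


def kl_direct(x, y):
    """Closed form of the (generalized) KL loss, for x, y > 0."""
    return x * math.log(x / y) - x + y
```

For any positive x and y, decomposed_loss(SQEUCL, x, y) should agree with sqeucl_direct(x, y), and decomposed_loss(KL, x, y) with kl_direct(x, y), up to floating-point error.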
Parameters
  • geom_xx (Geometry) – Ground geometry of the first space.

  • geom_yy (Geometry) – Ground geometry of the second space.

  • geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.

  • fused_penalty (float) – multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if geom_xy is not specified.

  • scale_cost (Union[bool, float, str, None]) –

    option to rescale the cost matrices:

    • if True, use the default for each geometry.

    • if False, keep the original scaling in geometries.

    • if str, use a specific method available in Geometry or PointCloud.

    • if None, do not scale the cost matrices.

  • a (Optional[ndarray]) – jnp.ndarray[n] representing the probability weights of the samples from geom_xx. If None, uniform weights are used.

  • b (Optional[ndarray]) – jnp.ndarray[m] representing the probability weights of the samples from geom_yy. If None, uniform weights are used.

  • loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – a 2-tuple of 2-tuples of Callable. The first tuple is the linear part of the loss (lin1 and lin2 in the class description above). The second one is the quadratic part (quad1, quad2). By default, the loss is set to the 4 functions representing the squared Euclidean loss, and this property is exploited in subsequent computations. Alternatively, the KL loss can be specified, in a no less optimized way.

  • tau_a (Optional[float]) – if lower than 1.0, defines how unbalanced the problem is on the first marginal.

  • tau_b (Optional[float]) – if lower than 1.0, defines how unbalanced the problem is on the second marginal.

  • gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise tau_a and tau_b only affect the inner Sinkhorn loop.

  • ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Used when geometries are not PointCloud with 'sqeucl' cost function. If -1, the geometries will not be converted to low-rank. If a tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If an int, the rank is shared across all geometries.

  • tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when geometries are not PointCloud with 'sqeucl' cost. If a float, it is shared across all geometries.
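To make the roles of geom_xy, fused_penalty and the default uniform weights concrete, here is a minimal pure-Python sketch of the fused objective described above (purely quadratic + fused_penalty * linear). It works on plain nested lists rather than Geometry objects, assumes the squared Euclidean loss, and leaves out entropic regularization, cost scaling and unbalancedness. All helper names are hypothetical and do not belong to the library.

```python
def uniform_weights(n):
    """Default weights when a (or b) is None: 1/n on each of n points."""
    return [1.0 / n] * n


def quadratic_term(C, C_bar, T):
    """Sum over i, k, j, l of (C[i][k] - C_bar[j][l])^2 * T[i][j] * T[k][l]."""
    n, m = len(C), len(C_bar)
    total = 0.0
    for i in range(n):
        for k in range(n):
            for j in range(m):
                for l in range(m):
                    diff = C[i][k] - C_bar[j][l]
                    total += diff * diff * T[i][j] * T[k][l]
    return total


def linear_term(M, T):
    """Sum over i, j of M[i][j] * T[i][j] (cross-space cost, as in geom_xy)."""
    return sum(M[i][j] * T[i][j] for i in range(len(M)) for j in range(len(M[0])))


def fused_objective(C, C_bar, T, M=None, fused_penalty=1.0):
    """Quadratic term, plus fused_penalty * linear term when M is given."""
    value = quadratic_term(C, C_bar, T)
    if M is None:  # plain Gromov-Wasserstein: fused_penalty is ignored
        return value
    return value + fused_penalty * linear_term(M, T)


# Toy data: two points in each space, independent coupling of uniform weights.
C = [[0.0, 1.0], [1.0, 0.0]]      # intra-space costs of the first space
C_bar = [[0.0, 2.0], [2.0, 0.0]]  # intra-space costs of the second space
M = [[0.0, 1.0], [1.0, 0.0]]      # cross-space costs
a, b = uniform_weights(2), uniform_weights(2)
T = [[a_i * b_j for b_j in b] for a_i in a]

gw_value = fused_objective(C, C_bar, T)
fgw_value = fused_objective(C, C_bar, T, M=M, fused_penalty=2.0)
```

On this toy input the quadratic term is 1.5 and the linear term is 0.5, so gw_value is 1.5 and fgw_value is 1.5 + 2.0 * 0.5 = 2.5.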

Methods

cost_unbalanced_correction(transport_matrix, ...)

Calculate cost term from the quadratic divergence when unbalanced.

init_transport()

Initialise the transport matrix.

init_transport_mass()

Initialise the transport mass.

marginal_dependent_cost(marginal_1, marginal_2)

Initialise cost term that depends on the marginals of the transport.

to_low_rank([seed])

Convert geometries to low-rank.

update_linearization(transport[, epsilon, ...])

Update linearization of GW problem by updating cost matrix.

update_lr_geom(lr_sink)

Recompute (possibly LRC) linearization using LR Sinkhorn output.

update_lr_linearization(lr_sink)

Update the linearization of a quadratic problem using an LR Sinkhorn output.
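The linearization methods above rely on the decomposition of the loss into lin1, lin2, quad1 and quad2. The sketch below shows the underlying identity from [Peyré et al., 2016] on plain nested lists: for a coupling T with marginals p and q, the linearized cost splits into a term that depends on T only through its marginals, minus a cross term. This is not the library's implementation; all function names are hypothetical, and any scaling or regularization applied by the library is ignored.

```python
def marginals(T):
    """Row sums p and column sums q of the coupling T."""
    p = [sum(row) for row in T]
    q = [sum(T[i][j] for i in range(len(T))) for j in range(len(T[0]))]
    return p, q


def marginal_cost_sketch(C, C_bar, p, q, lin1, lin2):
    """Entry (i, j): sum_k lin1(C[i][k]) * p[k] + sum_l lin2(C_bar[j][l]) * q[l]."""
    n, m = len(C), len(C_bar)
    row = [sum(lin1(C[i][k]) * p[k] for k in range(n)) for i in range(n)]
    col = [sum(lin2(C_bar[j][l]) * q[l] for l in range(m)) for j in range(m)]
    return [[row[i] + col[j] for j in range(m)] for i in range(n)]


def linearized_cost_sketch(C, C_bar, T, lin1, lin2, quad1, quad2):
    """Marginal-dependent term minus the cross term quad1(C) T quad2(C_bar)^T."""
    n, m = len(C), len(C_bar)
    p, q = marginals(T)
    base = marginal_cost_sketch(C, C_bar, p, q, lin1, lin2)
    return [
        [
            base[i][j]
            - sum(
                quad1(C[i][k]) * T[k][l] * quad2(C_bar[j][l])
                for k in range(n)
                for l in range(m)
            )
            for j in range(m)
        ]
        for i in range(n)
    ]


def brute_force_cost(C, C_bar, T, loss):
    """Reference: entry (i, j) is sum_{k,l} loss(C[i][k], C_bar[j][l]) * T[k][l]."""
    n, m = len(C), len(C_bar)
    return [
        [
            sum(loss(C[i][k], C_bar[j][l]) * T[k][l]
                for k in range(n) for l in range(m))
            for j in range(m)
        ]
        for i in range(n)
    ]


# Squared Euclidean loss and one possible decomposition of it (assumed split).
sq_loss = lambda x, y: (x - y) ** 2
sq_parts = (lambda x: x * x, lambda y: y * y, lambda x: x, lambda y: 2.0 * y)

C = [[0.0, 1.0], [1.0, 0.0]]
C_bar = [[0.0, 2.0, 3.0], [2.0, 0.0, 1.0], [3.0, 1.0, 0.0]]
T = [[0.2, 0.1, 0.2], [0.1, 0.3, 0.1]]

fast = linearized_cost_sketch(C, C_bar, T, *sq_parts)
slow = brute_force_cost(C, C_bar, T, sq_loss)
```

Here fast and slow agree entry by entry up to floating-point error. The decomposed form is what makes the update cheap in practice: the marginal-dependent part reduces to matrix-vector products, and the cross term to two matrix products.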

Attributes

a

First marginal.

b

Second marginal.

geom_xx

Geometry of the first space.

geom_xy

Geometry of the joint space.

geom_yy

Geometry of the second space.

is_balanced

Whether the problem is balanced.

is_fused

Whether the problem is fused.

is_low_rank

Whether all geometries are low-rank.

linear_loss

Linear part of the Gromov-Wasserstein loss.

quad_loss

Quadratic part of the Gromov-Wasserstein loss.