ott.problems.quadratic.quadratic_problem.QuadraticProblem

class ott.problems.quadratic.quadratic_problem.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=None, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01)

Quadratic OT problem.

The quadratic loss of a single transport matrix \(T\), namely \(\sum_{i,j,k,l} L(C_{ik}, \bar{C}_{jl}) T_{ij} T_{kl}\), is assumed to have the form given in [Peyré et al., 2016], eq. 4.

The two geometries geom_xx and geom_yy parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) in that equation is assumed to have the form given in eq. 5; in our notation:

\[L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)\]
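
As a sanity check of this decomposition, the sketch below verifies numerically the splits of the two losses discussed in [Peyré et al., 2016]: for the squared Euclidean loss, \(f_1(x) = x^2\), \(f_2(y) = y^2\), \(h_1(x) = x\), \(h_2(y) = 2y\); for the KL loss \(L(x, y) = x \log(x/y) - x + y\), \(f_1(x) = x \log x - x\), \(f_2(y) = y\), \(h_1(x) = x\), \(h_2(y) = \log y\). It uses plain NumPy rather than OTT; the names sqeucl, kl and from_parts are illustrative and not part of the library's API, and we assume (without reproducing the library code) that the 'sqeucl' and 'kl' options of loss correspond to these splits.

```python
import numpy as np

# Elementwise losses L(x, y) for the two named options ('sqeucl', 'kl').
def sqeucl(x, y):
    return (x - y) ** 2

def kl(x, y):
    return x * np.log(x / y) - x + y

# Splits (f1, f2, h1, h2) such that L(x, y) = f1(x) + f2(y) - h1(x) * h2(y).
SQEUCL_PARTS = (lambda x: x ** 2, lambda y: y ** 2, lambda x: x, lambda y: 2.0 * y)
KL_PARTS = (lambda x: x * np.log(x) - x, lambda y: y, lambda x: x, lambda y: np.log(y))

def from_parts(parts, x, y):
    """Rebuild L(x, y) from its four component functions."""
    f1, f2, h1, h2 = parts
    return f1(x) + f2(y) - h1(x) * h2(y)

rng = np.random.default_rng(0)
x = rng.uniform(0.1, 2.0, size=1000)  # positive values, so the KL loss is defined
y = rng.uniform(0.1, 2.0, size=1000)

assert np.allclose(sqeucl(x, y), from_parts(SQEUCL_PARTS, x, y))
assert np.allclose(kl(x, y), from_parts(KL_PARTS, x, y))
```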

Parameters:
  • geom_xx (Geometry) – Ground geometry of the first space.

  • geom_yy (Geometry) – Ground geometry of the second space.

  • geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for fused Gromov-Wasserstein [Titouan et al., 2019]. If None, the problem reduces to a plain Gromov-Wasserstein problem [Peyré et al., 2016].

  • fused_penalty (float) – Multiplier of the linear term in fused Gromov-Wasserstein, i.e. the objective is the purely quadratic term plus fused_penalty times the linear term.

  • scale_cost (Union[float, str, None]) – How to rescale the cost matrices. If a str, it must be one of the options available in Geometry or PointCloud. If None, the original scaling is kept.

  • a (Optional[Array]) – The first marginal. If None, it will be uniform.

  • b (Optional[Array]) – The second marginal. If None, it will be uniform.

  • loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – Gromov-Wasserstein loss function, see GWLoss for more information.

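  • tau_a (float) – If \(< 1.0\), the constraint on the first marginal is relaxed; smaller values make the problem more unbalanced on that marginal. A value of 1.0 enforces the marginal exactly.

  • tau_b (float) – If \(< 1.0\), the constraint on the second marginal is relaxed; smaller values make the problem more unbalanced on that marginal. A value of 1.0 enforces the marginal exactly.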

  • gw_unbalanced_correction (bool) – Whether to use the unbalanced Gromov-Wasserstein formulation of [Sejourne et al., 2021]. Otherwise, tau_a and tau_b only affect the inner Sinkhorn loop.

  • ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Only used for geometries that are not a PointCloud with the ‘sqeucl’ cost. If -1, the geometries are not converted to low-rank. If a tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If an int, the same rank is used for all geometries.

  • tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Only used for geometries that are not a PointCloud with the ‘sqeucl’ cost. If a float, the same tolerance is used for all geometries.
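
To make the roles of geom_xx, geom_yy, geom_xy, fused_penalty, a and b concrete, the following NumPy sketch evaluates the fused objective of a given coupling \(T\) by brute force: the quadratic term \(\sum_{i,j,k,l} L(C_{ik}, \bar{C}_{jl}) T_{ij} T_{kl}\) with the squared Euclidean \(L\), plus fused_penalty times the linear term \(\langle M, T \rangle\). The arrays C, Cbar and M stand in for the cost matrices of the three geometries; gw_energy and fgw_energy are hypothetical helpers rather than OTT functions, and the sketch ignores cost rescaling and the unbalanced case.

```python
import numpy as np

def gw_energy(C, Cbar, T, loss=lambda x, y: (x - y) ** 2):
    """Brute-force quadratic term: sum_{i,j,k,l} L(C_ik, Cbar_jl) T_ij T_kl."""
    # 4-D tensor indexed (i, k, j, l) holding L(C[i, k], Cbar[j, l]).
    L = loss(C[:, :, None, None], Cbar[None, None, :, :])
    return np.einsum("ikjl,ij,kl->", L, T, T)

def fgw_energy(C, Cbar, M, T, fused_penalty=1.0):
    """Quadratic term plus fused_penalty times the linear term <M, T>."""
    return gw_energy(C, Cbar, T) + fused_penalty * np.sum(M * T)

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))  # 5 points in the first space
y = rng.normal(size=(4, 3))  # 4 points in the second space (another dimension)
C = np.sum((x[:, None] - x[None]) ** 2, axis=-1)     # plays the role of geom_xx
Cbar = np.sum((y[:, None] - y[None]) ** 2, axis=-1)  # plays the role of geom_yy
M = rng.uniform(size=(5, 4))  # hypothetical cross-space cost, like geom_xy

a = np.full(5, 1 / 5)  # uniform marginals, as when a and b are None
b = np.full(4, 1 / 4)
T = np.outer(a, b)     # independent coupling: satisfies both marginals

print(fgw_energy(C, Cbar, M, T, fused_penalty=0.5))
```

The brute-force tensor has \(n^2 m^2\) entries, so this is only practical for tiny problems; it is best seen as a reference for checking a faster implementation.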

Methods

cost_unbalanced_correction(transport_matrix, ...)

Calculate the cost term coming from the quadratic divergence in the unbalanced case.

init_transport_mass()

Initialize the transport mass.

marginal_dependent_cost(marginal_1, marginal_2)

Initialize the cost term that depends on the marginals of the transport.

to_low_rank([rng])

Convert geometries to low-rank.

update_linearization(transport[, epsilon, ...])

Update the linearization of the GW problem by updating the cost matrix.

update_lr_geom(lr_sink[, relative_epsilon])

Recompute the (possibly low-rank cost, LRC) linearization using the low-rank (LR) Sinkhorn output.

update_lr_linearization(lr_sink, *[, ...])

Update the linearization of a quadratic problem using a low-rank Sinkhorn output.
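
The decomposition of \(L\) is what makes linearizing the problem cheap. Following the tensor-matrix product result of [Peyré et al., 2016] (Proposition 1), \((L \otimes T)_{ij} = \sum_{k,l} L(C_{ik}, \bar{C}_{jl}) T_{kl}\) equals \(f_1(C) p \mathbf{1}^\top + \mathbf{1} q^\top f_2(\bar{C})^\top - h_1(C) T h_2(\bar{C})^\top\), where \(p = T\mathbf{1}\) and \(q = T^\top\mathbf{1}\) are the marginals of \(T\). The first two terms depend on \(T\) only through its marginals, which matches the description of marginal_dependent_cost. The NumPy sketch below checks this identity against a brute-force sum for the 'sqeucl' loss; the helper names are ours, and we assume, rather than claim, that update_linearization builds its cost matrix along these lines (the library version also deals with epsilon, rescaling and unbalanced corrections, all omitted here).

```python
import numpy as np

# Split of the squared Euclidean loss: L(x, y) = f1(x) + f2(y) - h1(x) * h2(y).
def f1(x):
    return x ** 2

def f2(y):
    return y ** 2

def h1(x):
    return x

def h2(y):
    return 2.0 * y

def marginal_cost(C, Cbar, p, q):
    """Part of L ⊗ T that depends on T only through p = T 1 and q = T^T 1."""
    return (f1(C) @ p)[:, None] + (f2(Cbar) @ q)[None, :]

def tensor_product(C, Cbar, T):
    """(L ⊗ T)_ij via the decomposition, in O(n^2 m + n m^2) operations."""
    p, q = T.sum(axis=1), T.sum(axis=0)
    return marginal_cost(C, Cbar, p, q) - h1(C) @ T @ h2(Cbar).T

def tensor_product_naive(C, Cbar, T):
    """sum_kl (C_ik - Cbar_jl)^2 T_kl, materializing an n x m x n x m tensor."""
    L = (C[:, None, :, None] - Cbar[None, :, None, :]) ** 2  # indexed (i, j, k, l)
    return np.einsum("ijkl,kl->ij", L, T)

rng = np.random.default_rng(1)
x, y = rng.normal(size=(6, 2)), rng.normal(size=(5, 2))
C = np.sum((x[:, None] - x[None]) ** 2, axis=-1)     # 6 x 6 intra-space cost
Cbar = np.sum((y[:, None] - y[None]) ** 2, axis=-1)  # 5 x 5 intra-space cost
T = rng.uniform(size=(6, 5))
T /= T.sum()  # arbitrary nonnegative coupling with total mass 1

lin_cost = tensor_product(C, Cbar, T)
assert np.allclose(lin_cost, tensor_product_naive(C, Cbar, T))
energy = np.sum(lin_cost * T)  # quadratic loss of T, i.e. <L ⊗ T, T>
```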

Attributes

a

First marginal.

b

Second marginal.

geom_xx

Geometry of the first space.

geom_xy

Geometry of the joint space.

geom_yy

Geometry of the second space.

is_balanced

Whether the problem is balanced.

is_fused

Whether the problem is fused.

is_low_rank

Whether all geometries are low-rank.

linear_loss

Linear part of the Gromov-Wasserstein loss.

quad_loss

Quadratic part of the Gromov-Wasserstein loss.
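
Regarding is_low_rank, to_low_rank() and the ranks and tolerances parameters: for a point cloud with the squared Euclidean cost, the cost matrix already factors exactly as \(A B^\top\) with \(d + 2\) columns, since \(\|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2 x_i^\top y_j\). This is presumably why ranks and tolerances only apply to other geometries, which need an approximate factorization instead. The NumPy sketch below builds these exact factors; sqeucl_factors is an illustrative name, not an OTT function, and it does not reproduce how to_LRCGeometry() handles other costs.

```python
import numpy as np

def sqeucl_factors(x, y):
    """Return A (n, d+2) and B (m, d+2) with A @ B.T == ||x_i - y_j||^2 exactly."""
    n, m = x.shape[0], y.shape[0]
    x_sq = np.sum(x ** 2, axis=1, keepdims=True)  # ||x_i||^2 as a column
    y_sq = np.sum(y ** 2, axis=1, keepdims=True)  # ||y_j||^2 as a column
    A = np.concatenate([x_sq, np.ones((n, 1)), -2.0 * x], axis=1)
    B = np.concatenate([np.ones((m, 1)), y_sq, y], axis=1)
    return A, B

rng = np.random.default_rng(0)
x, y = rng.normal(size=(100, 3)), rng.normal(size=(80, 3))
A, B = sqeucl_factors(x, y)
full = np.sum((x[:, None] - y[None]) ** 2, axis=-1)  # dense 100 x 80 cost matrix
assert np.allclose(A @ B.T, full)
```

Storing A and B takes \(O((n + m)(d + 2))\) memory instead of \(O(nm)\) for the dense matrix, which is what makes low-rank representations attractive for large point clouds.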