ott.solvers.quadratic.solve

ott.solvers.quadratic.solve(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, a=None, b=None, tau_a=1.0, tau_b=1.0, loss='sqeucl', gw_unbalanced_correction=True, rank=-1, **kwargs)

Solve a quadratic regularized OT problem using a Gromov-Wasserstein solver.
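As a rough sketch of what "purely quadratic + fused_penalty * linear problem" means, consider the balanced case (tau_a = tau_b = 1) with the default 'sqeucl' loss, leaving out the entropic regularization term. The symbols \(C^X\), \(C^Y\) and \(M\) (cost matrices of geom_xx, geom_yy and geom_xy) and \(U(a, b)\) (couplings with marginals a and b) are notation assumed here for illustration, and the library may use a different constant scaling:

\[
\min_{P \in U(a, b)} \; \sum_{i,k} \sum_{j,l} \left(C^X_{ik} - C^Y_{jl}\right)^2 P_{ij} P_{kl} \; + \; \texttt{fused\_penalty} \cdot \sum_{i,j} M_{ij} P_{ij}
\]

When geom_xy is None, the second sum is absent and only the quadratic term remains.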

Parameters:
  • geom_xx (Geometry) – Ground geometry of the first space.

  • geom_yy (Geometry) – Ground geometry of the second space.

  • geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for fused Gromov-Wasserstein [Titouan et al., 2019]. If None, the problem reduces to a plain Gromov-Wasserstein problem [Peyré et al., 2016].

  • fused_penalty (float) – Multiplier of the linear term in fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem.

  • a (Optional[Array]) – The first marginal. If None, it will be uniform.

  • b (Optional[Array]) – The second marginal. If None, it will be uniform.

  • tau_a (float) – If \(< 1\), controls how unbalanced the problem is on the first marginal.

  • tau_b (float) – If \(< 1\), controls how unbalanced the problem is on the second marginal.

  • loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – Gromov-Wasserstein loss function, see GWLoss for more information. If rank > 0, 'sqeucl' is always used.

  • gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the resolution of the linearization of the GW problem in the inner loop. Only used when rank = -1.

  • rank (int) – Rank constraint on the coupling to minimize the quadratic OT problem [Scetbon et al., 2022]. If \(-1\), no rank constraint is used.

  • kwargs (Any) – Keyword arguments for GromovWasserstein or LRGromovWasserstein, depending on the rank.

Return type:

Union[GWOutput, LRGWOutput]

Returns:

The Gromov-Wasserstein output.