ott.solvers.quadratic.gromov_wasserstein.solve

ott.solvers.quadratic.gromov_wasserstein.solve(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01, **kwargs)

Solve a quadratic regularized OT problem.

The quadratic loss of a single OT matrix is assumed to have the form given in [Peyré et al., 2016], eq. 4.

The two geometries below parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) in that equation is assumed to take the form given in eq. 5, which in our notation reads:

\[L(x, y) = \text{lin1}(x) + \text{lin2}(y) - \text{quad1}(x) \cdot \text{quad2}(y)\]
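As an illustration of this decomposition, the sketch below writes two losses as ((lin1, lin2), (quad1, quad2)) and checks them numerically against their closed forms. It uses plain Python rather than the library's own loss objects. The names SQEUCL, KL and decomposed_loss are hypothetical, and the KL variant assumes the generalized form x log(x / y) - x + y.

```python
import math

# Squared Euclidean loss: (x - y)^2 = x^2 + y^2 - x * (2 * y).
SQEUCL = (
    (lambda x: x ** 2, lambda y: y ** 2),  # (lin1, lin2)
    (lambda x: x, lambda y: 2.0 * y),      # (quad1, quad2)
)

# Generalized KL loss (assumed form):
# x * log(x / y) - x + y = (x * log(x) - x) + y - x * log(y).
KL = (
    (lambda x: x * math.log(x) - x, lambda y: y),  # (lin1, lin2)
    (lambda x: x, lambda y: math.log(y)),          # (quad1, quad2)
)


def decomposed_loss(loss, x, y):
    """Evaluate L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)."""
    (lin1, lin2), (quad1, quad2) = loss
    return lin1(x) + lin2(y) - quad1(x) * quad2(y)


x, y = 3.0, 0.5
sqeucl_value = decomposed_loss(SQEUCL, x, y)
kl_value = decomposed_loss(KL, x, y)
```

Both values agree with the direct formulas (x - y) ** 2 and x * log(x / y) - x + y up to floating-point error.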
Parameters:
  • geom_xx (Geometry) – Ground geometry of the first space.

  • geom_yy (Geometry) – Ground geometry of the second space.

  • geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.

  • fused_penalty (float) – multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if geom_xy is not specified.

  • scale_cost (Union[bool, float, str, None]) –

    option to rescale the cost matrices:

    • if True, use the default for each geometry.

    • if False, keep the original scaling in geometries.

    • if str, use a specific method available in Geometry or PointCloud.

    • if None, do not scale the cost matrices.

  • a (Optional[Array]) – Array of probability weights for the samples from geom_xx. If None, uniform weights are used.

  • b (Optional[Array]) – Array of probability weights for the samples from geom_yy. If None, uniform weights are used.

  • loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – a 2-tuple of 2-tuples of Callable. The first tuple is the linear part of the loss (lin1, lin2). The second one is the quadratic part (quad1, quad2). By default, the loss is the set of 4 functions representing the squared Euclidean loss, and this property is exploited in subsequent computations. Alternatively, the KL loss can be specified, in a less optimized way.

  • tau_a (Optional[float]) – if < 1.0, defines how unbalanced the problem is on the first marginal.

  • tau_b (Optional[float]) – if < 1.0, defines how unbalanced the problem is on the second marginal.

  • gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the inner Sinkhorn loop.

  • ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Used when geometries are not PointCloud with ‘sqeucl’ cost function. If -1, the geometries will not be converted to low-rank. If tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If int, rank is shared across all geometries.

  • tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when geometries are not PointCloud with ‘sqeucl’ cost. If float, it is shared across all geometries.

  • kwargs (Any) – Keyword arguments for GromovWasserstein.
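The scalar-or-tuple convention described for ranks and tolerances can be sketched in plain Python. The helper per_geometry below is hypothetical and is not part of the library. It only illustrates how a single value is shared across geom_xx, geom_yy and geom_xy, and how a tuple assigns one value to each, assuming the tuple has exactly three entries.

```python
from typing import Dict, Tuple, Union

Number = Union[int, float]

# Order in which a tuple is interpreted, as described in the parameter list.
GEOMETRY_NAMES = ("geom_xx", "geom_yy", "geom_xy")


def per_geometry(value: Union[Number, Tuple[Number, ...]]) -> Dict[str, Number]:
    """Hypothetical helper: map a scalar or a 3-tuple onto the three geometries."""
    if isinstance(value, tuple):
        if len(value) != len(GEOMETRY_NAMES):
            raise ValueError(
                f"Expected {len(GEOMETRY_NAMES)} values, got {len(value)}."
            )
        return dict(zip(GEOMETRY_NAMES, value))
    # A scalar is shared across all geometries.
    return {name: value for name in GEOMETRY_NAMES}


shared_ranks = per_geometry(-1)           # no low-rank conversion anywhere
mixed_ranks = per_geometry((10, 20, -1))  # one rank per geometry
shared_tolerances = per_geometry(0.01)    # one tolerance for all geometries
```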

Return type:

GWOutput

Returns:

Gromov-Wasserstein output.