ott.core.gromov_wasserstein.gromov_wasserstein

ott.core.gromov_wasserstein.gromov_wasserstein(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01, **kwargs)

Solve a Gromov Wasserstein problem.

Wrapper that instantiates a quadratic problem (with an additional linear term if the problem is fused) and calls a solver on it to produce a solution.
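To make "quadratic problem (possibly with linear term)" concrete, here is a brute-force NumPy sketch of the objective such a problem scores for a fixed coupling T, assuming the default 'sqeucl' loss and the convention loss = quadratic_loss + fused_penalty * linear_loss stated under fused_penalty. The helper fused_gw_objective is hypothetical and not part of ott; it builds an n × m × n × m tensor, so it only illustrates what is being minimized, and ott's internal scaling conventions (for example via scale_cost) may differ.

```python
import numpy as np


def fused_gw_objective(cx, cy, coupling, cxy=None, fused_penalty=1.0):
    """Hypothetical helper: (fused) GW objective for a fixed coupling.

    cx:       [n, n] cost matrix within the first view.
    cy:       [m, m] cost matrix within the second view.
    coupling: [n, m] transport plan T.
    cxy:      optional [n, m] cross cost matrix (the fused, linear term).
    """
    # diff[i, j, k, l] = cx[i, k] - cy[j, l]
    diff = cx[:, None, :, None] - cy[None, :, None, :]
    # Squared-Euclidean GW loss: sum_{ijkl} (cx[i,k] - cy[j,l])^2 T[i,j] T[k,l]
    quadratic = np.einsum("ijkl,ij,kl->", diff**2, coupling, coupling)
    if cxy is None:
        # No cross geometry: plain (non-fused) Gromov Wasserstein.
        return quadratic
    # Linear (Wasserstein-like) term <cxy, T>, weighted by fused_penalty.
    linear = np.sum(cxy * coupling)
    return quadratic + fused_penalty * linear
```

For two identical views matched by the identity coupling np.eye(n) / n, the quadratic term is zero, so any remaining value comes from the fused linear term alone.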

Parameters
  • geom_xx (Geometry) – a Geometry object for the first view.

  • geom_yy (Geometry) – a second Geometry object for the second view.

  • geom_xy (Optional[Geometry]) – a Geometry object representing the linear cost in Fused Gromov Wasserstein (FGW). If not specified, the problem is not fused.

  • fused_penalty (float) – multiplier of the linear term in Fused Gromov Wasserstein, i.e. loss = quadratic_loss + fused_penalty * linear_loss. Ignored if geom_xy is not specified.

  • scale_cost (Union[bool, float, str, None]) – option to rescale the cost matrices:

  • a (Optional[ndarray]) – weights of the first view, given as a jnp.ndarray<float>[num_a,] or a batched jnp.ndarray<float>[batch,num_a].

  • b (Optional[ndarray]) – weights of the second view, given as a jnp.ndarray<float>[num_b,] or a batched jnp.ndarray<float>[batch,num_b].

  • loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – defaults to the squared Euclidean distance. Pass 'kl' to use the Kullback-Leibler (KL) loss as the GW loss instead. See GromovWasserstein for how to pass a custom loss.

  • tau_a (Optional[float]) – float between 0 and 1.0 that controls the strength of the KL divergence constraint between the weights and the marginal of the transport plan for the first view. A value of 1.0 is equivalent to a hard constraint; smaller values give a softer constraint.

  • tau_b (Optional[float]) – float between 0 and 1.0 that controls the strength of the KL divergence constraint between the weights and the marginal of the transport plan for the second view. A value of 1.0 is equivalent to a hard constraint; smaller values give a softer constraint.

  • gw_unbalanced_correction (bool) – True (default) to use the unbalanced formulation of [Séjourné et al., 2021]; False if tau_a and tau_b should only affect the inner Sinkhorn loop.

  • ranks (Union[int, Tuple[int, ...]]) – switch to a low-rank approximation of all cost matrices, computed with to_LRCGeometry(), to gain speed. This is only relevant for geometries that are not PointCloud with the 'sqeucl' cost function; those are already low-rank by construction (as long as the number of points in each cloud is larger than the dimension). If -1, the geometries are left as they are and not converted. If a tuple, its 2 or 3 ints specify the ranks of geom_xx, geom_yy and geom_xy, respectively. If an int, all 3 geometries are converted using that rank.

  • tolerances (Union[float, Tuple[float, ...]]) – tolerances used when converting geometries to low rank, i.e. for geometries that are not PointCloud with the 'sqeucl' cost. If a float, that tolerance is shared across all 3 geometries.

  • kwargs (Any) – Keyword arguments to GromovWasserstein.
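The "low-rank by construction" remark under ranks can be checked directly: for a squared-Euclidean cost between point clouds in dimension d, the cost matrix factors exactly as A Bᵀ with inner dimension d + 2, so its rank is at most d + 2 regardless of the number of points. The NumPy sketch below shows that factorization, together with a hypothetical normalize_ranks helper that expands ranks and tolerances into one (rank, tolerance) pair per geometry following the rules listed above. Neither function is part of ott, and reading a 2-tuple as "geom_xx and geom_yy only" (no geom_xy) is an assumption.

```python
import numpy as np


def sqeucl_cost(x, y):
    """Dense cost C[i, j] = ||x_i - y_j||^2."""
    return np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def sqeucl_low_rank_factors(x, y):
    """Exact factors a [n, d+2], b [m, d+2] with a @ b.T == sqeucl_cost(x, y)."""
    n, m = x.shape[0], y.shape[0]
    sq_x = np.sum(x**2, axis=1, keepdims=True)  # [n, 1]
    sq_y = np.sum(y**2, axis=1, keepdims=True)  # [m, 1]
    # ||x||^2 + ||y||^2 - 2 <x, y> as an inner product of augmented vectors.
    a = np.concatenate([sq_x, np.ones((n, 1)), -2.0 * x], axis=1)
    b = np.concatenate([np.ones((m, 1)), sq_y, y], axis=1)
    return a, b


def normalize_ranks(ranks, tolerances, num_geoms=3):
    """Hypothetical helper: one (rank, tolerance) pair per geometry.

    num_geoms is assumed to be 2 for plain GW (geom_xx, geom_yy) and 3 for
    fused GW (geom_xx, geom_yy, geom_xy). A rank of -1 means "do not convert".
    """
    if isinstance(ranks, int):
        ranks = (ranks,) * num_geoms  # a single int is shared by all geometries
    if isinstance(tolerances, (int, float)):
        tolerances = (tolerances,) * num_geoms  # a single float is shared too
    ranks, tolerances = tuple(ranks), tuple(tolerances)
    if len(ranks) != num_geoms or len(tolerances) != num_geoms:
        raise ValueError(
            f"expected {num_geoms} ranks and tolerances, "
            f"got {len(ranks)} and {len(tolerances)}"
        )
    return tuple(zip(ranks, tolerances))
```

With x of shape [20, 3], the [20, 20] matrix sqeucl_cost(x, x) has rank at most 5, which is why converting such a geometry with to_LRCGeometry() would not gain anything.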

Return type

GWOutput

Returns

A GWOutput named tuple.