ott.core.gromov_wasserstein.gromov_wasserstein

ott.core.gromov_wasserstein.gromov_wasserstein(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss=None, tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, **kwargs)

Solve a Gromov-Wasserstein problem.

Wrapper that instantiates a quadratic problem (with an additional linear term if the problem is fused) and calls a solver on it to produce a solution.
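For orientation, the objective such a quadratic problem encodes can be sketched in the standard form of Peyré, Cuturi and Solomon (2016). This is an assumption about notation, not the solver's exact formulation: the entropic regularization the solver uses and any cost normalization are omitted. Here $C^x$, $C^y$ and $C^{xy}$ denote the cost matrices of geom_xx, geom_yy and geom_xy, $\lambda$ is fused_penalty, $L$ is the loss, and $U(a,b)$ is the set of couplings with marginals a and b:

$$
\min_{T \in U(a,b)} \; \sum_{i,k} \sum_{j,l} L\big(C^{x}_{ik}, C^{y}_{jl}\big)\, T_{ij}\, T_{kl} \;+\; \lambda \sum_{i,j} C^{xy}_{ij}\, T_{ij},
\qquad U(a,b) = \{\, T \ge 0 : T\mathbf{1} = a,\; T^{\top}\mathbf{1} = b \,\}
$$

With loss=None, $L(u,v) = (u - v)^2$; with loss='kl', $L(u,v) = u \log(u/v) - u + v$. The linear term is dropped when geom_xy is None, and setting tau_a or tau_b below 1.0 replaces the corresponding hard marginal constraint with a KL penalty.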

Parameters
  • geom_xx (Geometry) – a Geometry object for the first view.

  • geom_yy (Geometry) – a Geometry object for the second view.

  • geom_xy (Optional[Geometry]) – a Geometry object representing the linear cost in Fused Gromov-Wasserstein (FGW). If None, a plain (non-fused) GW problem is solved.

  • fused_penalty (float) – multiplier of the linear term in FGW, i.e. loss = quadratic_loss + fused_penalty * linear_loss. Ignored if geom_xy is None.

  • scale_cost (Union[bool, float, str, None]) – option to rescale the cost matrices.

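  • a (Optional[ndarray]) – weights of the points in the first view, as a jnp.ndarray<float>[num_a,] or, for batched inputs, a jnp.ndarray<float>[batch, num_a].

  • b (Optional[ndarray]) – weights of the points in the second view, as a jnp.ndarray<float>[num_b,] or, for batched inputs, a jnp.ndarray<float>[batch, num_b].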

  • loss (Optional[str]) – name of the GW loss. None (default) uses the squared Euclidean loss; ‘kl’ uses the KL divergence.

  • tau_a (Optional[float]) – float between 0 and 1.0 that controls the strength of the KL divergence penalty between the first marginal of the transport plan and the weights a. A value of 1.0 enforces the marginal constraint exactly (hard constraint); smaller values relax it (soft constraint).

  • tau_b (Optional[float]) – float between 0 and 1.0 that controls the strength of the KL divergence penalty between the second marginal of the transport plan and the weights b. A value of 1.0 enforces the marginal constraint exactly (hard constraint); smaller values relax it (soft constraint).

  • gw_unbalanced_correction (bool) – True (default) to use the unbalanced GW formulation of Séjourné et al. (NeurIPS 2021); False if tau_a and tau_b should only affect the inner Sinkhorn loop.

  • kwargs (Any) – keyword arguments passed to make, which builds the GromovWasserstein solver.
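To make the roles of geom_xx, geom_yy, geom_xy, fused_penalty and loss concrete, the pure-Python sketch below evaluates the unregularized (fused) GW objective for a fixed coupling T, given dense cost matrices. It is not OTT code: the names sq_euclidean, kl and fgw_objective are hypothetical, and it assumes the textbook form of the objective (Peyré, Cuturi and Solomon, 2016), so its values need not match the costs the solver reports, which include regularization.

```python
import math
from typing import Callable, List, Optional

Matrix = List[List[float]]


def sq_euclidean(u: float, v: float) -> float:
    """Squared Euclidean discrepancy (assumed meaning of loss=None)."""
    return (u - v) ** 2


def kl(u: float, v: float) -> float:
    """Generalized KL discrepancy u*log(u/v) - u + v (assumed meaning of loss='kl').

    Only defined for strictly positive u and v.
    """
    return u * math.log(u / v) - u + v


def fgw_objective(
    cx: Matrix,
    cy: Matrix,
    coupling: Matrix,
    loss: Callable[[float, float], float] = sq_euclidean,
    cxy: Optional[Matrix] = None,
    fused_penalty: float = 1.0,
) -> float:
    """Evaluates the unregularized (fused) GW objective for a fixed coupling.

    cx:       n x n cost matrix of the first view (role of geom_xx).
    cy:       m x m cost matrix of the second view (role of geom_yy).
    coupling: n x m transport plan T.
    cxy:      optional n x m linear cost (role of geom_xy).
    """
    n, m = len(cx), len(cy)
    # Quadratic term: sum over i, i2 (first view) and j, j2 (second view) of
    # L(cx[i][i2], cy[j][j2]) * T[i][j] * T[i2][j2].
    # Naive O(n^2 m^2) loop, written for clarity rather than speed.
    quadratic = 0.0
    for i in range(n):
        for j in range(m):
            for i2 in range(n):
                for j2 in range(m):
                    quadratic += (
                        loss(cx[i][i2], cy[j][j2]) * coupling[i][j] * coupling[i2][j2]
                    )
    if cxy is None:
        # No linear cost: plain GW, so fused_penalty is unused.
        return quadratic
    # Linear (fused) term: <cxy, T>.
    linear = sum(cxy[i][j] * coupling[i][j] for i in range(n) for j in range(m))
    return quadratic + fused_penalty * linear


# Usage: two identical two-point spaces.
cx = [[0.0, 1.0], [1.0, 0.0]]
identity_plan = [[0.5, 0.0], [0.0, 0.5]]
product_plan = [[0.25, 0.25], [0.25, 0.25]]
print(fgw_objective(cx, cx, identity_plan))  # 0.0: all pairwise distances preserved
print(fgw_objective(cx, cx, product_plan))  # 0.5
print(fgw_objective(cx, cx, product_plan, cxy=cx, fused_penalty=2.0))  # 1.5
```

The identity plan matches each point to its copy and therefore pays no quadratic cost, whereas the independent (product) plan mixes distances of 0 and 1 and is penalized. Note that loss='kl' requires strictly positive cost entries, so it cannot be applied as written to distance matrices with a zero diagonal.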

Return type

GWOutput

Returns

A GWOutput named tuple.