ott.tools.sliced.sliced_wasserstein


ott.tools.sliced.sliced_wasserstein(x, y, a=None, b=None, cost_fn=None, proj_fn=None, weights=None, return_transport=False, return_dual_variables=False, **kwargs)

Compute the Sliced Wasserstein distance between two weighted point clouds.

Follows the approach outlined in [Rabin et al., 2012] to compute a proxy for OT distances. The approach creates features for the data (possibly randomly), e.g. through projections, and then sums the 1D Wasserstein distances between the univariate distributions of these features on the source and target samples.
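The idea described above can be sketched in plain NumPy. This is a minimal illustration, not the library's implementation: it assumes two clouds of equal size with uniform weights, a squared Euclidean cost, and random directions on the unit sphere. The name sliced_wasserstein_sketch and its seed argument are hypothetical and exist only for this example.

```python
import numpy as np


def sliced_wasserstein_sketch(x, y, n_proj=100, seed=0):
    """Monte Carlo sliced Wasserstein cost for two equal-size, uniform clouds.

    Illustrative only: assumes x.shape == y.shape and a squared Euclidean cost.
    """
    rng = np.random.default_rng(seed)
    dim = x.shape[1]
    # Draw n_proj random directions and normalize them onto the unit sphere.
    directions = rng.normal(size=(n_proj, dim))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    # Project both clouds: the features have shape [n, n_proj].
    x_feat = x @ directions.T
    y_feat = y @ directions.T
    # In 1D, with uniform weights and equal sizes, the optimal coupling for
    # this cost matches sorted source points to sorted target points.
    x_sorted = np.sort(x_feat, axis=0)
    y_sorted = np.sort(y_feat, axis=0)
    per_proj_cost = np.mean((x_sorted - y_sorted) ** 2, axis=0)  # shape [n_proj]
    # Uniform average of the per-projection 1D costs.
    return per_proj_cost.mean()


rng = np.random.default_rng(42)
x = rng.normal(size=(64, 3))
y = rng.normal(size=(64, 3)) + 1.0  # shifted target cloud
print(sliced_wasserstein_sketch(x, x))  # identical clouds give zero cost
print(sliced_wasserstein_sketch(x, y))  # strictly positive for distinct clouds
```

The sorting step is what makes each 1D problem cheap. Handling unequal sizes or non-uniform weights a and b requires a more general 1D solver than this sketch provides.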

Parameters:
  • x (Array) – Array of shape [n, dim] of source points’ coordinates.

  • y (Array) – Array of shape [m, dim] of target points’ coordinates.

  • a (Optional[Array]) – Array of shape [n,] of source probability weights.

  • b (Optional[Array]) – Array of shape [m,] of target probability weights.

  • cost_fn (Optional[CostFn]) – Cost function. Must be a submodular function of two real arguments, i.e. such that \(\partial^2 c(x,y)/\partial x \partial y < 0\). If None, use SqEuclidean.

  • proj_fn (Optional[Callable[[Array, int, Array], Array]]) – Projection function, mapping any [b, dim] matrix of coordinates to [b, n_proj] matrix of features, on which 1D transports (for n_proj directions) are subsequently computed independently. By default, use random_proj_sphere().

  • weights (Optional[Array]) – Array of shape [n_proj,] of weights used to average the n_proj 1D Wasserstein contributions (one for each feature) and form the sliced Wasserstein distance. Uniform by default, resulting in the plain average of these values.

  • return_transport (bool) – Whether to store n_proj transport plans in the output.

  • return_dual_variables (bool) – Whether to store n_proj pairs of dual vectors in the output.

  • kwargs (Any) – Keyword arguments passed to proj_fn. With the default projector, these can include the number of projections n_proj, as well as an rng key used to sample that many directions.
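As a quick check of the condition on cost_fn, the squared Euclidean cost on the real line satisfies it, since its mixed second derivative is a negative constant. The second formula is one way to read the role of weights. The symbols \(f_k\) (the k-th feature produced by proj_fn), \(W_c\) (the 1D transport cost under \(c\)) and \(w_k\) are notation introduced here for illustration, not names used by the library.

```latex
c(x, y) = (x - y)^2
\quad\Longrightarrow\quad
\frac{\partial c}{\partial x} = 2(x - y),
\qquad
\frac{\partial^2 c}{\partial x \, \partial y} = -2 < 0 .

\mathrm{SW}(\mu, \nu) \approx \sum_{k=1}^{n_{\mathrm{proj}}} w_k \, W_c\big(f_k(\mu), f_k(\nu)\big),
\qquad
w_k = \frac{1}{n_{\mathrm{proj}}} \ \text{by default.}
```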

Return type:

Tuple[Array, UnivariateOutput]

Returns:

The sliced Wasserstein distance with the corresponding output object.