ott.tools.sliced.sliced_wasserstein
ott.tools.sliced.sliced_wasserstein(x, y, a=None, b=None, cost_fn=None, proj_fn=None, weights=None, return_transport=False, return_dual_variables=False, **kwargs)
Compute the Sliced Wasserstein distance between two weighted point clouds.
Follows the approach outlined in [Rabin et al., 2012] to compute a proxy for OT distances. Features are first created for the data (possibly at random), e.g. through projections; the 1D Wasserstein distances between the univariate distributions of each feature on the source and target samples are then computed and combined.
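The idea above can be sketched in a few lines of NumPy. This is only an illustration, not the library's implementation: it assumes two clouds of equal size with uniform weights, a squared Euclidean cost, random directions on the unit sphere, and a uniform average over projections. The name `sliced_wasserstein_sketch` and its `seed` argument are hypothetical.

```python
import numpy as np


def sliced_wasserstein_sketch(x, y, n_proj=64, seed=0):
    """Illustrative sliced Wasserstein cost for equal-size, uniform clouds.

    Assumptions (not taken from the library): x and y both have shape
    [n, dim], all points carry weight 1/n, and the 1D cost is squared
    Euclidean.
    """
    rng = np.random.default_rng(seed)
    dim = x.shape[1]

    # Sample n_proj random directions and normalize them to unit length.
    directions = rng.normal(size=(n_proj, dim))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)

    # Project both clouds: [n, dim] -> [n, n_proj] features.
    x_proj = x @ directions.T
    y_proj = y @ directions.T

    # In 1D, with equal sizes and uniform weights, the optimal plan for a
    # convex cost matches sorted values, so sorting solves each problem.
    x_sorted = np.sort(x_proj, axis=0)
    y_sorted = np.sort(y_proj, axis=0)
    per_projection = np.mean((x_sorted - y_sorted) ** 2, axis=0)  # [n_proj]

    # Uniform average of the n_proj one-dimensional costs.
    return float(per_projection.mean())


# Example usage on two small Gaussian clouds.
rng = np.random.default_rng(42)
source = rng.normal(size=(100, 5))
target = rng.normal(loc=1.0, size=(100, 5))
cost = sliced_wasserstein_sketch(source, target)
```

Unequal sizes or non-uniform weights require a genuine 1D transport computation (for instance through quantile functions) rather than plain sorting, which this sketch does not cover.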
- Parameters:
x (Array) – Array of shape [n, dim] of source points’ coordinates.
y (Array) – Array of shape [m, dim] of target points’ coordinates.
a (Optional[Array]) – Array of shape [n,] of source probability weights.
b (Optional[Array]) – Array of shape [m,] of target probability weights.
cost_fn (Optional[CostFn]) – Cost function. Must be a submodular function of two real arguments, i.e. such that \(\partial^2 c(x,y)/\partial x \partial y < 0\). If None, use SqEuclidean.
proj_fn (Optional[Callable[[Array, int, Array], Array]]) – Projection function, mapping any [b, dim] matrix of coordinates to a [b, n_proj] matrix of features, on which 1D transports (for n_proj directions) are subsequently computed independently. By default, use random_proj_sphere().
weights (Optional[Array]) – Array of shape [n_proj,] of weights used to average the n_proj 1D Wasserstein contributions (one for each feature) and form the sliced Wasserstein distance. Uniform by default, resulting in the average of all these values.
return_transport (bool) – Whether to store n_proj transport plans in the output.
return_dual_variables (bool) – Whether to store n_proj pairs of dual vectors in the output.
kwargs (Any) – Keyword arguments to proj_fn. With the default projector, these could for instance include the number n_proj of projections, as well as an rng key used to sample that many directions.
- Return type:
- Returns:
The sliced Wasserstein distance, together with the corresponding output object.
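To make the shape contract of proj_fn and the role of weights more concrete, here is a minimal NumPy sketch. Both functions are hypothetical stand-ins: `sphere_features` mimics a projector that maps a [b, dim] matrix to [b, n_proj] features (using a NumPy generator rather than a JAX rng key), and `combine_slices` assumes that weights are normalized before averaging, which the text above does not state explicitly.

```python
import numpy as np


def sphere_features(x, n_proj, rng):
    """Hypothetical projector: maps [b, dim] coordinates to [b, n_proj] features.

    `rng` is a numpy.random.Generator here, an assumption made to keep the
    sketch dependency-free; the library's projector takes a JAX rng key.
    """
    directions = rng.normal(size=(n_proj, x.shape[1]))
    directions /= np.linalg.norm(directions, axis=1, keepdims=True)
    return x @ directions.T


def combine_slices(per_projection_costs, weights=None):
    """Average n_proj one-dimensional costs, uniformly unless weights are given.

    Assumption: weights are normalized to sum to one, as numpy.average does.
    """
    return float(np.average(per_projection_costs, weights=weights))


# Example usage: 3 features for 5 points in dimension 4.
features = sphere_features(np.ones((5, 4)), n_proj=3, rng=np.random.default_rng(0))
uniform = combine_slices(np.array([1.0, 2.0, 3.0]))
weighted = combine_slices(np.array([1.0, 2.0, 3.0]), weights=np.array([1.0, 0.0, 0.0]))
```

Any deterministic or learned feature map with the same input and output shapes would satisfy the same contract; random directions on the sphere are simply the default choice described above.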