ott.tools.sliced.sliced_wasserstein
- ott.tools.sliced.sliced_wasserstein(x, y, a=None, b=None, cost_fn=None, proj_fn=None, weights=None, return_transport=False, return_dual_variables=False, **kwargs)
Compute the Sliced Wasserstein distance between two weighted point clouds.
Follows the approach outlined in [Rabin et al., 2012] to compute a proxy for OT distances: features are created for the data (possibly at random), e.g. through projections, and the 1D Wasserstein distances between the univariate distributions of these features on the source and target samples are then summed.
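The projection-and-sum idea described above can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not OTT's implementation: it assumes uniform weights, equally sized point clouds, the squared Euclidean cost, and directions sampled uniformly on the unit sphere. The helper names `random_directions` and `sliced_w2_sketch` are hypothetical.

```python
import numpy as np


def random_directions(n_proj, dim, seed=0):
    """Sample n_proj unit vectors uniformly on the sphere (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=(n_proj, dim))
    return theta / np.linalg.norm(theta, axis=1, keepdims=True)


def sliced_w2_sketch(x, y, theta):
    """Average of 1D squared-Euclidean OT costs over the projections.

    Assumes uniform weights and x.shape == y.shape, so each 1D problem
    is solved by matching sorted projections.
    """
    x_proj = np.sort(x @ theta.T, axis=0)  # [n, n_proj], sorted per direction
    y_proj = np.sort(y @ theta.T, axis=0)  # [n, n_proj], sorted per direction
    per_direction = np.mean((x_proj - y_proj) ** 2, axis=0)  # [n_proj]
    return per_direction.mean()


rng = np.random.default_rng(1)
x = rng.normal(size=(64, 3))
shift = np.array([1.0, -2.0, 0.5])
theta = random_directions(n_proj=50, dim=3)

sw_same = sliced_w2_sketch(x, x, theta)             # identical clouds
sw_shifted = sliced_w2_sketch(x, x + shift, theta)  # translated cloud
```

For a pure translation, every 1D cost equals the squared projection of the shift onto the direction, so `sw_shifted` agrees with `np.mean((theta @ shift) ** 2)` up to rounding, while `sw_same` is zero.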
- Parameters:
x (Array) – Array of shape [n, dim] of source points’ coordinates.
y (Array) – Array of shape [m, dim] of target points’ coordinates.
a (Optional[Array]) – Array of shape [n,] of source probability weights.
b (Optional[Array]) – Array of shape [m,] of target probability weights.
cost_fn (Optional[CostFn]) – Cost function. Must be a submodular function of two real arguments, i.e. such that \(\partial^2 c(x,y)/\partial x \partial y < 0\). If None, use SqEuclidean.
proj_fn (Optional[Callable[[Array, int, Array], Array]]) – Projection function, mapping any [b, dim] matrix of coordinates to a [b, n_proj] matrix of features, on which 1D transports (for n_proj directions) are subsequently computed independently. By default, use random_proj_sphere().
weights (Optional[Array]) – Array of shape [n_proj,] of weights used to average the n_proj 1D Wasserstein contributions (one for each feature) and form the sliced Wasserstein distance. Uniform by default, which yields the plain average of these values.
return_transport (bool) – Whether to store the n_proj transport plans in the output.
return_dual_variables (bool) – Whether to store the n_proj pairs of dual vectors in the output.
kwargs (Any) – Keyword arguments passed to proj_fn. With the default projector these can include, for instance, the number n_proj of projections and an rng key used to sample that many directions.
- Return type:
- Returns:
The sliced Wasserstein distance with the corresponding output object.
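To illustrate what a single 1D transport looks like when the weights a and b are not uniform and the two clouds have different sizes, here is a small NumPy sketch. It is not OTT's solver: it assumes the squared Euclidean cost (a submodular cost, for which the monotone coupling of sorted points is optimal), and the helper name `ot_1d_sketch` is hypothetical.

```python
import numpy as np


def ot_1d_sketch(x, y, a, b):
    """1D OT with cost (x - y)**2 via the monotone (sorted) coupling.

    Returns the transport cost and the [n, m] transport plan.
    Hypothetical helper, not part of OTT.
    """
    ix, iy = np.argsort(x), np.argsort(y)
    a_left = a[ix].astype(float)  # remaining source mass, in sorted order
    b_left = b[iy].astype(float)  # remaining target mass, in sorted order
    plan = np.zeros((len(x), len(y)))
    i = j = 0
    while i < len(x) and j < len(y):
        mass = min(a_left[i], b_left[j])  # move as much mass as possible
        plan[ix[i], iy[j]] += mass
        a_left[i] -= mass
        b_left[j] -= mass
        # Advance whichever side is exhausted (small tolerance for rounding).
        source_done = a_left[i] <= 1e-12
        target_done = b_left[j] <= 1e-12
        i += int(source_done)
        j += int(target_done)
    cost = np.sum(plan * (x[:, None] - y[None, :]) ** 2)
    return cost, plan


x = np.array([0.0, 1.0, 3.0])
a = np.array([0.5, 0.25, 0.25])
y = np.array([2.0, 0.5])
b = np.array([0.5, 0.5])
cost, plan = ot_1d_sketch(x, y, a, b)
```

The rows of `plan` sum to `a` and its columns sum to `b`; on this toy input the cost is 0.625. In the sliced setting, one such plan exists per projection direction, which is what return_transport refers to.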