ott.tools.segment_sinkhorn.segment_sinkhorn
- ott.tools.segment_sinkhorn.segment_sinkhorn(x, y, num_segments=None, max_measure_size=None, cost_fn=None, segment_ids_x=None, segment_ids_y=None, indices_are_sorted=False, num_per_segment_x=None, num_per_segment_y=None, weights_x=None, weights_y=None, sinkhorn_kwargs=mappingproxy({}), **kwargs)
Compute regularized OT cost between subsets of vectors in x and y.
Helper function designed to compute Sinkhorn regularized OT cost between several point clouds of varying size, in parallel, using padding. In practice, the inputs x and y (and their weight vectors weights_x and weights_y) are assumed to be large weighted point clouds that describe points taken from multiple measures. To extract several subsets of points, we provide two interfaces. The first interface assumes that a vector of IDs is passed, describing for each point of x (resp. y) the measure to which it belongs. The second interface assumes that x and y were simply formed by concatenating several measures contiguously, and that only indices that segment these groups are needed to recover them.
For both interfaces, x and y should contain the same total number of segments. Each segment will be padded as necessary, all segments rearranged as a tensor, and jax.vmap() used to evaluate Sinkhorn divergences in parallel.
- Parameters:
x (Array) – Array of input points, of shape [num_x, dim]. Multiple segments are held in this single array.
y (Array) – Array of target points, of shape [num_y, dim].
num_segments (Optional[int]) – Number of segments contained in x and y. Providing this is required for JIT compilation to work, see also segment_point_cloud().
max_measure_size (Optional[int]) – Total size of measures after padding. Should ideally be set to an upper bound on the size of the point clouds processed with the segment interface. Providing this is required for JIT compilation to work.
cost_fn (Optional[CostFn]) – Cost function, defaults to SqEuclidean.
segment_ids_x (Optional[Array]) – 1st interface. The segment ID to which each row of x belongs. This is a similar interface to jax.ops.segment_sum.
segment_ids_y (Optional[Array]) – 1st interface. The segment ID to which each row of y belongs.
indices_are_sorted (bool) – 1st interface. Whether segment_ids_x and segment_ids_y are sorted.
num_per_segment_x (Optional[Tuple[int, ...]]) – 2nd interface. Number of points in each segment in x. For example, [100, 20, 30] would imply that x is segmented into 3 arrays of length [100], [20], and [30], respectively.
num_per_segment_y (Optional[Tuple[int, ...]]) – 2nd interface. Number of points in each segment in y.
weights_x (Optional[Array]) – Weights of the input points, arranged in the same segmented order as x.
weights_y (Optional[Array]) – Weights of the target points, arranged in the same segmented order as y.
sinkhorn_kwargs (Mapping[str, Any]) – Optionally a dict containing the keyword arguments for calls to the Sinkhorn solver, called three times to evaluate, for each segment, the Sinkhorn regularized OT cost between x/y, x/x, and y/y (except when static_b is True, in which case y/y is not evaluated).
kwargs (Any) – Keyword arguments passed to form PointCloud geometry objects from the subsets of points and masses selected in x and y, possibly a CostFn or an entropy regularizer.
- Return type:
Array
- Returns:
An array of Sinkhorn regularized OT costs for each segment.
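The segmentation and padding scheme described above can be illustrated with a minimal sketch in plain JAX. It does not call OTT and is not the library's implementation: pad_segments and weighted_mean are hypothetical helpers introduced only for illustration, and a weighted mean stands in for the per-segment Sinkhorn solve. The sketch shows that a tuple of segment sizes (2nd interface) and a vector of segment IDs (1st interface) describe the same segmentation, and that padding with zero-weight points lets jax.vmap() process segments of different sizes as one tensor.

```python
import jax
import jax.numpy as jnp


def pad_segments(x, weights, num_per_segment, max_measure_size):
    """Hypothetical helper: split contiguous segments, pad each to a fixed size.

    Padded rows are zero vectors with zero weight, so they carry no mass.
    """
    xs, ws, offset = [], [], 0
    for n in num_per_segment:
        pad = max_measure_size - n
        xs.append(jnp.pad(x[offset:offset + n], ((0, pad), (0, 0))))
        ws.append(jnp.pad(weights[offset:offset + n], (0, pad)))
        offset += n
    # Shapes: [num_segments, max_measure_size, dim], [num_segments, max_measure_size]
    return jnp.stack(xs), jnp.stack(ws)


def weighted_mean(points, w):
    """Stand-in for the per-segment solve; zero-weight rows do not contribute."""
    return (w[:, None] * points).sum(axis=0) / w.sum()


x = jnp.arange(12.0).reshape(6, 2)  # 6 points in 2 dimensions
weights = jnp.ones(6)
num_per_segment = (3, 1, 2)  # 2nd interface: contiguous segment sizes
num_segments = len(num_per_segment)

# 1st interface: the same segmentation written as one ID per point.
segment_ids = jnp.repeat(
    jnp.arange(num_segments),
    jnp.array(num_per_segment),
    total_repeat_length=x.shape[0],
)

# Pad every segment to 4 points, then evaluate all segments in parallel.
padded_x, padded_w = pad_segments(x, weights, num_per_segment, max_measure_size=4)
means_vmap = jax.vmap(weighted_mean)(padded_x, padded_w)

# Cross-check with an ID-based reduction, in the style of jax.ops.segment_sum.
sums = jax.ops.segment_sum(weights[:, None] * x, segment_ids, num_segments=num_segments)
mass = jax.ops.segment_sum(weights, segment_ids, num_segments=num_segments)
means_ids = sums / mass[:, None]
```

Both routes give the same per-segment result, which is the property that makes padding safe here: the fixed sizes num_segments and max_measure_size determine the tensor shape in advance, which is consistent with the note above that they are required for JIT compilation.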