ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence
- ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence(x, y, num_segments=None, max_measure_size=None, cost_fn=None, segment_ids_x=None, segment_ids_y=None, indices_are_sorted=False, num_per_segment_x=None, num_per_segment_y=None, weights_x=None, weights_y=None, solve_kwargs=mappingproxy({}), static_b=False, share_epsilon=True, symmetric_sinkhorn=False, **kwargs)
Compute Sinkhorn divergence between subsets of vectors in x and y.
Helper function designed to compute Sinkhorn divergences between several point clouds of varying size, in parallel, using padding for efficiency. In practice, the inputs x and y (and their weight vectors weights_x and weights_y) are assumed to be large weighted point clouds that describe points taken from multiple measures. Two interfaces are provided to extract the subsets of points. The first interface assumes that a vector of IDs is passed, describing for each point of x (resp. y) the measure to which it belongs. The second interface assumes that x and y were simply formed by concatenating several measures contiguously, and that only the sizes of these contiguous groups (num_per_segment_x, num_per_segment_y) are needed to recover them.
For both interfaces, x and y should contain the same number of segments. Each segment is padded as necessary, all segments are stacked into a single tensor, and vmap is used to evaluate the Sinkhorn divergences in parallel.
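The relationship between the two interfaces, and the padding step, can be illustrated with a minimal plain-Python sketch. The helpers ids_from_counts and pad_segments are hypothetical names introduced here for illustration only; they are not part of the library, and the assumption that padded rows are zero vectors with zero weight is a simplification of whatever the library does internally.

```python
from itertools import accumulate


def ids_from_counts(num_per_segment):
    """Expand per-segment counts (2nd interface) into one id per row (1st interface)."""
    return [seg for seg, n in enumerate(num_per_segment) for _ in range(n)]


def pad_segments(points, weights, num_per_segment, max_measure_size):
    """Split contiguous rows into segments and pad each one to a common size."""
    dim = len(points[0])
    offsets = [0, *accumulate(num_per_segment)]
    padded_points, padded_weights = [], []
    for start, n in zip(offsets, num_per_segment):
        if n > max_measure_size:
            raise ValueError("max_measure_size must bound every segment size")
        pad = max_measure_size - n
        # Assumption: padded rows are zero vectors carrying zero weight,
        # so they add no mass to the measure they are appended to.
        padded_points.append(
            points[start:start + n] + [[0.0] * dim for _ in range(pad)]
        )
        padded_weights.append(weights[start:start + n] + [0.0] * pad)
    return padded_points, padded_weights


# Two measures stored contiguously: 3 points, then 2 points.
points = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
weights = [0.5, 0.25, 0.25, 0.5, 0.5]
counts = (3, 2)

segment_ids = ids_from_counts(counts)
padded_points, padded_weights = pad_segments(
    points, weights, counts, max_measure_size=4
)
```

After padding, every segment has the same number of rows, so the segments form a rectangular [num_segments, max_measure_size, feature] tensor, which is the shape a batched evaluation needs.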
- Parameters:
x (Array) – Array of input points, of shape [num_x, feature]. Multiple segments are held in this single array.
y (Array) – Array of target points, of shape [num_y, feature].
num_segments (Optional[int]) – Number of segments contained in x and y. Providing this is required for JIT compilation to work; see also segment_point_cloud().
max_measure_size (Optional[int]) – Total size of measures after padding. Should ideally be set to an upper bound on the size of the point clouds processed with the segment interface. Should also be smaller than the total length of x or y. Providing this is required for JIT compilation to work.
cost_fn (Optional[CostFn]) – Cost function, defaults to SqEuclidean.
segment_ids_x (Optional[Array]) – (1st interface) The segment ID to which each row of x belongs. This interface is similar to that of jax.ops.segment_sum().
segment_ids_y (Optional[Array]) – (1st interface) The segment ID to which each row of y belongs.
indices_are_sorted (bool) – (1st interface) Whether segment_ids_x and segment_ids_y are sorted.
num_per_segment_x (Optional[Tuple[int, ...]]) – (2nd interface) Number of points in each segment in x. For example, [100, 20, 30] would imply that x is segmented into three arrays of length [100], [20], and [30] respectively.
num_per_segment_y (Optional[Tuple[int, ...]]) – (2nd interface) Number of points in each segment in y.
weights_x (Optional[Array]) – Weight of each input point, arranged in the same segmented order as x.
weights_y (Optional[Array]) – Weight of each target point, arranged in the same segmented order as y.
solve_kwargs (Mapping[str, Any]) – Optionally, a dict containing the keyword arguments for calls to the sinkhorn function, which is called three times to evaluate, for each segment, the Sinkhorn regularized OT cost between x/y, x/x, and y/y (except when static_b is True, in which case y/y is not evaluated).
static_b (bool) – If True, the divergence of measure b (here, y) against itself is not computed.
share_epsilon (bool) – If True, enforces that the same epsilon regularizer is shared across all 2 or 3 terms of the Sinkhorn divergence. In that case, the epsilon will by default be the one used when comparing x to y (contained in the first geometry). This flag is set to True by default because, in the default setting, the epsilon regularization is a function of the mean of the cost matrix.
symmetric_sinkhorn (bool) – Use Sinkhorn updates in Eq. 25 of [Feydy et al., 2019] for symmetric terms comparing x/x and y/y.
kwargs (Any) – Keyword arguments passed to form PointCloud geometry objects from the subsets of points and masses selected in x and y; this could be, for instance, the entropy regularization float, a scheduler, or normalization.
- Return type:
Array
- Returns:
An array of Sinkhorn divergence values for each segment.