# ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence

ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence(x, y, num_segments=None, max_measure_size=None, cost_fn=None, segment_ids_x=None, segment_ids_y=None, indices_are_sorted=False, num_per_segment_x=None, num_per_segment_y=None, weights_x=None, weights_y=None, sinkhorn_kwargs=mappingproxy({}), static_b=False, share_epsilon=True, symmetric_sinkhorn=False, **kwargs)

Compute Sinkhorn divergence between subsets of vectors in x and y.

Helper function designed to compute Sinkhorn divergences between several point clouds of varying size, in parallel, using padding for efficiency. In practice, the inputs x and y (and their weight vectors weights_x and weights_y) are assumed to be large weighted point clouds that describe points taken from multiple measures. To extract several subsets of points, we provide two interfaces. The first interface assumes that a vector of IDs is passed, describing for each point of x (resp. y) the measure to which the point belongs. The second interface assumes that x and y were simply formed by concatenating several measures contiguously, and that only the number of points in each of these groups is needed to recover them.
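To make the relation between the two interfaces concrete, here is a minimal sketch in plain JAX (it does not call into the library). The segment sizes are made up for illustration, and the conversion from IDs back to sizes only makes sense when the IDs are sorted, so that each measure is stored contiguously.

```python
import jax.numpy as jnp

# Hypothetical layout: three measures of sizes 3, 1 and 2, concatenated in order.
num_per_segment = jnp.array([3, 1, 2])
num_segments = num_per_segment.shape[0]

# 2nd interface -> 1st interface: repeat each segment index by its size.
segment_ids = jnp.repeat(jnp.arange(num_segments), num_per_segment)

# 1st interface -> 2nd interface: count how many points carry each ID.
counts = jnp.bincount(segment_ids, length=num_segments)

print(segment_ids)  # one ID per point
print(counts)       # one size per segment
```

Both descriptions carry the same information here; the ID vector is the more general one, since it also covers points that are not stored contiguously.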

For both interfaces, both x and y should contain the same total number of segments. Each segment will be padded as necessary, all segments rearranged as a tensor, and vmap used to evaluate Sinkhorn divergences in parallel.
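The padding-then-vmap idea can be sketched as follows. This is an illustration under simplifying assumptions, not the library's internal code: `pad_segments` and `weighted_mean` are hypothetical helpers, uniform weights are assumed, and a weighted mean stands in for the per-segment Sinkhorn divergence. The key point is that padded rows receive zero weight, so they do not affect the per-segment result.

```python
import jax
import jax.numpy as jnp

def pad_segments(x, num_per_segment, max_measure_size):
    """Split contiguous segments of x and pad each to max_measure_size rows."""
    padded, weights = [], []
    start = 0
    for n in num_per_segment:  # plain Python ints, so shapes stay static
        seg = x[start:start + n]
        pad = max_measure_size - n
        padded.append(jnp.pad(seg, ((0, pad), (0, 0))))
        # Uniform weights on real points, zero weight on padded rows.
        weights.append(jnp.pad(jnp.full((n,), 1.0 / n), (0, pad)))
        start += n
    return jnp.stack(padded), jnp.stack(weights)

def weighted_mean(points, w):
    # Stand-in for a per-segment computation; padding contributes nothing.
    return w @ points

x = jnp.arange(12.0).reshape(6, 2)
segs, w = pad_segments(x, [3, 1, 2], max_measure_size=4)

# All segments now share one shape, so a single vmap handles them in parallel.
means = jax.vmap(weighted_mean)(segs, w)
```

Because every padded segment has shape `[max_measure_size, feature]`, the stacked tensor has a fixed shape, which is also why `num_segments` and `max_measure_size` need to be known ahead of time under JIT.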

Parameters:
• x (`Array`) – Array of input points, of shape [num_x, feature]. Multiple segments are held in this single array.

• y (`Array`) – Array of target points, of shape [num_y, feature].

• num_segments () – Number of segments contained in x and y. Providing this is required for JIT compilation to work; see also `segment_point_cloud()`.

• max_measure_size () – Total size of measures after padding. Should ideally be set to an upper bound on the size of the point clouds processed with the segment interface. Should also be smaller than the total length of x or y. Providing this is required for JIT compilation to work.

• cost_fn () – Cost function, defaults to `SqEuclidean`.

• segment_ids_x () – (1st interface) The ID of the segment to which each row of x belongs. This is a similar interface to `jax.ops.segment_sum()`.

• segment_ids_y () – (1st interface) The ID of the segment to which each row of y belongs.

• indices_are_sorted (`bool`) – (1st interface) Whether segment_ids_x and segment_ids_y are sorted.

• num_per_segment_x () – (2nd interface) Number of points in each segment in x. For example, [100, 20, 30] would imply that x is segmented into three arrays of length 100, 20, and 30 respectively.

• num_per_segment_y () – (2nd interface) Number of points in each segment in y.

• weights_x () – Weight of each input point, arranged in the same segmented order as x.

• weights_y () – Weight of each target point, arranged in the same segmented order as y.

• sinkhorn_kwargs () – Optionally a dict containing the keyword arguments for calls to the sinkhorn function, which is called three times to evaluate, for each segment, the Sinkhorn regularized OT cost between x/y, x/x, and y/y (except when static_b is True, in which case y/y is not evaluated).

• static_b (`bool`) – If True, the divergence of the second measure (the points in y) against itself is not computed.

• share_epsilon (`bool`) – if True, enforces that the same epsilon regularizer is shared for all 2 or 3 terms of the Sinkhorn divergence. In that case, the epsilon will be by default that used when comparing x to y (contained in the first geometry). This flag is set to True by default, because in the default setting, the epsilon regularization is a function of the mean of the cost matrix.

• symmetric_sinkhorn (`bool`) – Use Sinkhorn updates in Eq. 25 of for symmetric terms comparing x/x and y/y.

• kwargs (`Any`) – Keyword arguments passed on to form the `PointCloud` geometry objects from the subsets of points and masses selected in x and y. These could be, for instance, the entropy regularization float, a scheduler, or a normalization.

Return type:

`Array`

Returns:

An array of Sinkhorn divergences for each segment.