# ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence

ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence(x, y, num_segments=None, max_measure_size=None, cost_fn=None, segment_ids_x=None, segment_ids_y=None, indices_are_sorted=False, num_per_segment_x=None, num_per_segment_y=None, weights_x=None, weights_y=None, sinkhorn_kwargs=mappingproxy({}), static_b=False, share_epsilon=True, symmetric_sinkhorn=False, **kwargs)

Compute Sinkhorn divergences between subsets of vectors given in x and y.

Helper function designed to compute Sinkhorn divergences between several point clouds of varying size, in parallel, using padding for efficiency. The inputs x and y (and their weight vectors weights_x and weights_y) are assumed to be large weighted point clouds that describe points taken from multiple measures. Two interfaces are provided to extract the subsets of points. The first interface assumes that a vector of IDs is passed, describing for each point of x (resp. y) the measure to which that point belongs. The second interface assumes that x and y were formed by concatenating several measures contiguously, so that only the sizes of these groups are needed to recover them.

For both interfaces, both x and y should contain the same total number of segments. Each segment will be padded as necessary, all segments rearranged as a tensor, and vmap used to evaluate sinkhorn divergences in parallel.
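The padding step described above can be illustrated with a minimal NumPy sketch. It is not the library's implementation: `pad_segments` is a hypothetical helper, and the choice of zero vectors with zero weight for padded rows is an assumption made so that padding does not change the measure. The sketch also shows how a `num_per_segment`-style description (second interface) maps to a `segment_ids`-style description (first interface).

```python
import numpy as np


def pad_segments(x, weights, segment_ids, num_segments, max_measure_size):
    """Hypothetical helper: stack the segments of x into one zero-padded tensor."""
    dim = x.shape[1]
    padded_x = np.zeros((num_segments, max_measure_size, dim))
    padded_w = np.zeros((num_segments, max_measure_size))
    for s in range(num_segments):
        mask = segment_ids == s
        n = int(mask.sum())
        padded_x[s, :n] = x[mask]
        # Assumption: padded rows keep weight 0, so they carry no mass.
        padded_w[s, :n] = weights[mask]
    return padded_x, padded_w


# Second interface: contiguous segments described by their sizes.
num_per_segment = (3, 1, 2)
# Equivalent first-interface description: one segment ID per row of x.
segment_ids = np.repeat(np.arange(len(num_per_segment)), num_per_segment)

x = np.arange(12, dtype=float).reshape(6, 2)  # 6 points in 2 dimensions
weights = np.ones(6)

padded_x, padded_w = pad_segments(
    x, weights, segment_ids, num_segments=3, max_measure_size=4
)
```

Once every segment has the same padded shape, a batched map over the leading axis (vmap in the actual function) can process all segments at once, which is why fixed values of num_segments and max_measure_size matter for JIT compilation.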

Parameters
• x (`Array`) – Array of input points, of shape [num_x, feature]. Multiple segments are held in this single array.

• y (`Array`) – Array of target points, of shape [num_y, feature].

• num_segments (`Optional`[`int`]) – Number of segments contained in x and y. Providing this is required for JIT compilation to work, see also `segment_point_cloud()`.

• max_measure_size (`Optional`[`int`]) – Total size of measures after padding. Should ideally be set to an upper bound on the size of the point clouds processed with the segment interface. Should also be smaller than the total length of x or y. Providing this is required for JIT compilation to work.

• cost_fn (`Optional`[`CostFn`]) – Cost function, defaults to `SqEuclidean`.

• segment_ids_x (`Optional`[`Array`]) – (1st interface) The segment ID to which each row of x belongs. This is a similar interface to `jax.ops.segment_sum()`.

• segment_ids_y (`Optional`[`Array`]) – (1st interface) The segment ID to which each row of y belongs.

• indices_are_sorted (`bool`) – (1st interface) Whether segment_ids_x and segment_ids_y are sorted.

• num_per_segment_x (`Optional`[`Tuple`[`int`, `...`]]) – (2nd interface) Number of points in each segment in x. For example, [100, 20, 30] would imply that x is segmented into three arrays of length 100, 20, and 30 respectively.

• num_per_segment_y (`Optional`[`Tuple`[`int`, `...`]]) – (2nd interface) Number of points in each segment in y.

• weights_x (`Optional`[`Array`]) – Weights of the input points, arranged in the same segmented order as x.

• weights_y (`Optional`[`Array`]) – Weights of the target points, arranged in the same segmented order as y.

• sinkhorn_kwargs (`Mapping`[`str`, `Any`]) – Optionally a dict containing the keyword arguments for calls to the sinkhorn function, which is called three times to evaluate, for each segment, the Sinkhorn regularized OT cost between x/y, x/x, and y/y (except when static_b is True, in which case y/y is not evaluated).

• static_b (`bool`) – If True, the divergence of measure b (the measure built from y) against itself is not computed.

• share_epsilon (`bool`) – if True, enforces that the same epsilon regularizer is shared for all 2 or 3 terms of the Sinkhorn divergence. In that case, the epsilon will be by default that used when comparing x to y (contained in the first geometry). This flag is set to True by default, because in the default setting, the epsilon regularization is a function of the mean of the cost matrix.

• symmetric_sinkhorn (`bool`) – Use Sinkhorn updates in Eq. 25 of [Feydy et al., 2019] for symmetric terms comparing x/x and y/y.

• kwargs (`Any`) – Keyword arguments passed to form `PointCloud` geometry objects from the subsets of points and masses selected in x and y. These could be, for instance, the entropy regularization float, a scheduler, or normalization.

Return type

`Array`

Returns

An array of sinkhorn divergence values for each segment.
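
To make the roles of the x/y, x/x and y/y terms and of static_b concrete, here is a minimal NumPy sketch of a Sinkhorn divergence for a single pair of point clouds. It is an illustration under stated assumptions, not the library's code: `reg_ot_cost` and `sinkhorn_divergence_sketch` are hypothetical names, the cost is squared Euclidean, the same epsilon is shared by all terms, and the regularized cost is read off the dual potentials after a fixed number of log-domain iterations.

```python
import numpy as np


def _logsumexp(m, axis):
    # Numerically stable log-sum-exp along one axis.
    mx = m.max(axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.exp(m - mx).sum(axis=axis))


def reg_ot_cost(x, a, y, b, epsilon, num_iters=500):
    """Entropy-regularized OT cost, estimated from log-domain Sinkhorn potentials."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean
    f, g = np.zeros(len(a)), np.zeros(len(b))
    log_a, log_b = np.log(a), np.log(b)
    for _ in range(num_iters):
        f = -epsilon * _logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        g = -epsilon * _logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
    # After the g update the coupling has total mass 1, so the dual value is <a, f> + <b, g>.
    return a @ f + b @ g


def sinkhorn_divergence_sketch(x, a, y, b, epsilon, static_b=False):
    """Hypothetical sketch: OT(x, y) - 0.5 * OT(x, x) - 0.5 * OT(y, y)."""
    div = reg_ot_cost(x, a, y, b, epsilon) - 0.5 * reg_ot_cost(x, a, x, a, epsilon)
    if not static_b:  # assumption: with static_b the y/y term is simply skipped
        div -= 0.5 * reg_ot_cost(y, b, y, b, epsilon)
    return div


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
y = rng.normal(size=(7, 2)) + 3.0  # shifted cloud
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)

div_xy = sinkhorn_divergence_sketch(x, a, y, b, epsilon=1.0)
div_xx = sinkhorn_divergence_sketch(x, a, x, a, epsilon=1.0)
div_static = sinkhorn_divergence_sketch(x, a, y, b, epsilon=1.0, static_b=True)
```

In this sketch the divergence of a cloud against itself is zero, and skipping the y/y term can only increase the value, which is why static_b is mainly useful when y is fixed and only gradients with respect to x matter. The actual function evaluates such a quantity for every padded segment pair and returns one value per segment.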