ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence

ott.tools.sinkhorn_divergence.segment_sinkhorn_divergence(x, y, num_segments=None, max_measure_size=None, cost_fn=None, segment_ids_x=None, segment_ids_y=None, indices_are_sorted=None, num_per_segment_x=None, num_per_segment_y=None, weights_x=None, weights_y=None, sinkhorn_kwargs=mappingproxy({}), static_b=False, share_epsilon=True, **kwargs)

Compute Sinkhorn divergences between subsets of vectors given in x and y.

Helper function designed to compute Sinkhorn divergences between several point clouds of varying size, in parallel, using padding for efficiency. In practice, the inputs x and y (and their weight vectors weights_x and weights_y) are assumed to be large weighted point clouds that describe points taken from multiple measures. To extract the individual subsets of points, two interfaces are provided. The first interface assumes that a vector of IDs is passed, describing for each point of x (resp. y) the measure to which it belongs. The second interface assumes that x and y were formed by concatenating several measures contiguously, so that only the sizes of these groups are needed to recover them.
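The two interfaces describe the same segmentation in different ways. As a minimal NumPy sketch (an illustration only, not part of OTT), sizes in the style of num_per_segment_x can be turned into per-row IDs in the style of segment_ids_x, and counted back again. IDs produced this way come out sorted.

```python
import numpy as np

# 2nd interface: x was formed by concatenating three measures of sizes 100, 20 and 30.
num_per_segment_x = np.array([100, 20, 30])

# Equivalent 1st-interface description: one segment ID per row of x (sorted here).
segment_ids_x = np.repeat(np.arange(len(num_per_segment_x)), num_per_segment_x)

# Going back: counting the rows carrying each ID recovers the segment sizes.
recovered_sizes = np.bincount(segment_ids_x, minlength=len(num_per_segment_x))
```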

For both interfaces, x and y should contain the same number of segments. Each segment is padded as necessary, all segments are rearranged into a single tensor, and vmap is used to evaluate the Sinkhorn divergences in parallel.
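The padding step can be pictured with the hypothetical helper pad_segments below. It is a plain NumPy sketch that assumes the 1st interface (per-row segment IDs) and gives padded slots zero weight so they carry no mass; OTT performs this rearrangement internally and its exact padding values may differ. Because the output shape [num_segments, max_measure_size, dim] is fixed by those two integers, the sketch also suggests why they must be known in advance for JIT compilation.

```python
import numpy as np

def pad_segments(points, weights, segment_ids, num_segments, max_measure_size):
    """Hypothetical helper: rearrange a flat segmented cloud into padded tensors.

    Returns points of shape [num_segments, max_measure_size, dim] and weights of
    shape [num_segments, max_measure_size]; padded slots get zero weight.
    """
    dim = points.shape[1]
    padded_x = np.zeros((num_segments, max_measure_size, dim), dtype=points.dtype)
    padded_w = np.zeros((num_segments, max_measure_size), dtype=weights.dtype)
    for s in range(num_segments):
        idx = np.flatnonzero(segment_ids == s)  # rows of `points` in segment s
        if len(idx) > max_measure_size:
            raise ValueError(f"segment {s} has {len(idx)} points > max_measure_size")
        padded_x[s, : len(idx)] = points[idx]
        padded_w[s, : len(idx)] = weights[idx]
    return padded_x, padded_w

# Five points split into a segment of size 2 and a segment of size 3.
points = np.arange(10.0).reshape(5, 2)
weights = np.full(5, 0.2)
seg_ids = np.array([0, 0, 1, 1, 1])
padded_x, padded_w = pad_segments(points, weights, seg_ids, num_segments=2, max_measure_size=3)
```

A function evaluating one divergence per padded segment could then be mapped over the leading axis with jax.vmap.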

Parameters
  • x (ndarray) – Array of input points, of shape [num_x, feature]. Multiple segments are held in this single array.

  • y (ndarray) – Array of target points, of shape [num_y, feature].

  • num_segments (Optional[int]) – Number of segments contained in x and y. Providing this number is required for JIT compilation to work; see also segment_point_cloud().

  • max_measure_size (Optional[int]) – Size of each measure after padding. Should ideally be set to an upper bound on the size of the point clouds processed with the segment interface. It should also be smaller than the total length of x or y. Providing this number is required for JIT compilation to work.

  • cost_fn (Optional[CostFn]) – Cost function, defaults to Euclidean.

  • segment_ids_x (Optional[ndarray]) – 1st interface: the segment ID to which each row of x belongs. This interface is similar to that of jax.ops.segment_sum().

  • segment_ids_y (Optional[ndarray]) – 1st interface: the segment ID to which each row of y belongs.

  • indices_are_sorted (Optional[bool]) – 1st interface: whether segment_ids_x and segment_ids_y are sorted. Defaults to False.

  • num_per_segment_x (Optional[ndarray]) – 2nd interface: number of points in each segment of x. For example, [100, 20, 30] implies that x is segmented into three arrays of lengths 100, 20, and 30, respectively.

  • num_per_segment_y (Optional[ndarray]) – 2nd interface: number of points in each segment of y.

  • weights_x (Optional[ndarray]) – Weights of the input points, arranged in the same segmented order as x.

  • weights_y (Optional[ndarray]) – Weights of the target points, arranged in the same segmented order as y.

  • sinkhorn_kwargs (Mapping[str, Any]) – Optional dict of keyword arguments for the calls to the sinkhorn function, which is called up to three times per segment to evaluate the Sinkhorn-regularized OT cost between x/y, x/x, and y/y (the y/y call is skipped when static_b is True).

  • static_b (bool) – If True, the regularized OT cost of measure b (the y points) against itself is NOT computed.

  • share_epsilon (bool) – If True, the same epsilon regularization is used for all two or three terms of the Sinkhorn divergence. In that case, epsilon defaults to the one used when comparing x to y (contained in the first geometry). This flag defaults to True because, in the default setting, epsilon is a function of the mean of the cost matrix, so each of the x/y, x/x and y/y geometries would otherwise get its own epsilon.

  • kwargs (Any) – Keyword arguments passed to form ott.geometry.pointcloud.PointCloud geometry objects from the subsets of points and masses selected in x and y, for instance the entropic regularization (a float), a scheduler, or a normalization.

Return type

ndarray

Returns

An array containing one Sinkhorn divergence value per segment.
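To make the roles of sinkhorn_kwargs, static_b and share_epsilon concrete, the NumPy sketch below computes the quantity returned for a single segment, assuming the usual debiased definition S(a, b) = OT_eps(a, b) - 0.5 * OT_eps(a, a) - 0.5 * OT_eps(b, b) with a squared-Euclidean cost. The helpers reg_ot_cost and sinkhorn_divergence are hypothetical: this is a plain log-domain Sinkhorn loop, not OTT's implementation, and treating the skipped y/y term as zero under static_b is an assumption.

```python
import numpy as np

def _logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along `axis`.
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))

def reg_ot_cost(x, y, a, b, epsilon, num_iters=500):
    """Hypothetical helper: entropic OT cost via log-domain Sinkhorn."""
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean
    f, g = np.zeros(len(a)), np.zeros(len(b))
    log_a, log_b = np.log(a), np.log(b)
    for _ in range(num_iters):
        # Alternate updates so that rows, then columns, match the marginals a and b.
        f = -epsilon * _logsumexp((g[None, :] - cost) / epsilon + log_b[None, :], axis=1)
        g = -epsilon * _logsumexp((f[:, None] - cost) / epsilon + log_a[:, None], axis=0)
    # Dual objective <f, a> + <g, b>; equals the entropic OT cost at convergence.
    return f @ a + g @ b

def sinkhorn_divergence(x, y, a, b, epsilon, static_b=False):
    # One epsilon for all terms, mirroring share_epsilon=True.
    xy = reg_ot_cost(x, y, a, b, epsilon)
    xx = reg_ot_cost(x, x, a, a, epsilon)
    # Assumed static_b behaviour: the y/y term is not evaluated and counts as 0.
    yy = 0.0 if static_b else reg_ot_cost(y, y, b, b, epsilon)
    return xy - 0.5 * (xx + yy)

rng = np.random.default_rng(0)
x = rng.uniform(size=(5, 2))
a = np.full(5, 1 / 5)
same = sinkhorn_divergence(x, x, a, a, epsilon=1.0)
shifted = sinkhorn_divergence(x, x + 3.0, a, a, epsilon=1.0)
```

Identical inputs give a divergence of 0, and for a translated copy with equal weights the entropic terms cancel at convergence, leaving the squared norm of the shift (18 for a shift of 3 in each of two coordinates).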