ott.tools.segment_sinkhorn.segment_sinkhorn


ott.tools.segment_sinkhorn.segment_sinkhorn(x, y, num_segments=None, max_measure_size=None, cost_fn=None, segment_ids_x=None, segment_ids_y=None, indices_are_sorted=False, num_per_segment_x=None, num_per_segment_y=None, weights_x=None, weights_y=None, sinkhorn_kwargs=mappingproxy({}), **kwargs)

Compute regularized OT cost between subsets of vectors in x and y.

Helper function designed to compute the Sinkhorn regularized OT cost between several point clouds of varying size, in parallel, using padding. In practice, the inputs x and y (and their weight vectors weights_x and weights_y) are assumed to be large weighted point clouds that describe points taken from multiple measures. To extract several subsets of points, we provide two interfaces. The first interface assumes that a vector of IDs is passed, describing for each point of x (resp. y) the measure to which it belongs. The second interface assumes that x and y were simply formed by concatenating several measures contiguously, so that only the number of points in each group is needed to recover them.
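The two interfaces describe the same segmentation. As a minimal sketch in plain NumPy (the helper names below are hypothetical and not part of ott), per-segment counts of the kind passed as num_per_segment_x can be expanded into per-point IDs of the kind passed as segment_ids_x, and counted back:

```python
import numpy as np


def segment_ids_from_counts(num_per_segment):
    """Expand per-segment counts (2nd interface) into per-point IDs (1st interface).

    Hypothetical helper for illustration only; not an ott function.
    """
    counts = np.asarray(num_per_segment, dtype=int)
    # Segment i's ID is repeated counts[i] times, giving contiguous blocks.
    return np.repeat(np.arange(len(counts)), counts)


def counts_from_segment_ids(segment_ids, num_segments):
    """Count how many points carry each segment ID (IDs need not be sorted)."""
    return np.bincount(np.asarray(segment_ids), minlength=num_segments)


# x formed by concatenating three measures with 3, 1 and 2 points.
ids = segment_ids_from_counts((3, 1, 2))
counts = counts_from_segment_ids(ids, num_segments=3)
print(ids)     # [0 0 0 1 2 2]
print(counts)  # [3 1 2]
```

The ID form is more general: it also covers points of one measure that are not stored contiguously, which is why the 1st interface carries the extra indices_are_sorted flag.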

For both interfaces, x and y should contain the same total number of segments. Each segment is padded as necessary, all segments are stacked into a single tensor, and jax.vmap() is used to evaluate the Sinkhorn regularized OT costs in parallel.
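The padding step can be pictured with the following NumPy sketch. It assumes that padded rows are filled with zeros and given zero weight so that they carry no mass; pad_segments is a hypothetical helper, and the library's own implementation may differ in how it fills padded rows or normalizes weights.

```python
import numpy as np


def pad_segments(points, weights, segment_ids, num_segments, max_measure_size):
    """Gather each segment into a fixed-size slot, padding with zero-weight rows.

    Returns arrays of shape [num_segments, max_measure_size, dim] and
    [num_segments, max_measure_size]. Illustrative sketch, not the ott code.
    """
    points = np.asarray(points)
    weights = np.asarray(weights)
    segment_ids = np.asarray(segment_ids)
    dim = points.shape[1]
    padded_x = np.zeros((num_segments, max_measure_size, dim), dtype=points.dtype)
    padded_w = np.zeros((num_segments, max_measure_size), dtype=weights.dtype)
    for s in range(num_segments):
        # Works for unsorted IDs: collect every row tagged with segment s.
        idx = np.flatnonzero(segment_ids == s)
        if idx.size > max_measure_size:
            raise ValueError(f"segment {s} has {idx.size} points > max_measure_size")
        padded_x[s, : idx.size] = points[idx]
        padded_w[s, : idx.size] = weights[idx]
    return padded_x, padded_w


x = np.arange(12.0).reshape(6, 2)
w = np.ones(6)
ids = np.array([0, 2, 0, 1, 2, 0])
px, pw = pad_segments(x, w, ids, num_segments=3, max_measure_size=4)
print(px.shape)  # (3, 4, 2)
```

Fixing max_measure_size ahead of time gives every segment the same static shape, which is what allows a single vectorized (and JIT-compiled) solve over the leading axis.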

Parameters:
  • x (Array) – Array of input points, of shape [num_x, feature]. Multiple segments are held in this single array.

  • y (Array) – Array of target points, of shape [num_y, feature].

  • num_segments (Optional[int]) – Number of segments contained in x and y. Providing this is required for JIT compilation to work; see also segment_point_cloud().

  • max_measure_size (Optional[int]) – Total size of measures after padding. Should ideally be set to an upper bound on the size of the point clouds processed with the segment interface. Providing this is required for JIT compilation to work.

  • cost_fn (Optional[CostFn]) – Cost function, defaults to SqEuclidean.

  • segment_ids_x (Optional[Array]) – 1st interface. The segment ID to which each row of x belongs. This interface is similar to that of jax.ops.segment_sum.

  • segment_ids_y (Optional[Array]) – 1st interface. The segment ID to which each row of y belongs.

  • indices_are_sorted (bool) – 1st interface. Whether segment_ids_x and segment_ids_y are sorted.

  • num_per_segment_x (Optional[Tuple[int, ...]]) – 2nd interface. Number of points in each segment in x. For example, (100, 20, 30) would imply that x is segmented into three arrays of length 100, 20, and 30, respectively.

  • num_per_segment_y (Optional[Tuple[int, ...]]) – 2nd interface. Number of points in each segment in y.

  • weights_x (Optional[Array]) – Weights of each input point, arranged in the same segmented order as x.

  • weights_y (Optional[Array]) – Weights of each target point, arranged in the same segmented order as y.

  • sinkhorn_kwargs (Mapping[str, Any]) – Optionally, a dict containing the keyword arguments for calls to the Sinkhorn solver, called three times to evaluate for each segment the Sinkhorn regularized OT cost between x/y, x/x, and y/y (except when static_b is True, in which case y/y is not evaluated).

  • kwargs (Any) – Keyword arguments passed when forming the PointCloud geometry objects from the subsets of points and masses selected in x and y, for instance a CostFn or an entropy regularizer.

Return type:

Array

Returns:

An array of Sinkhorn regularized OT costs for each segment.
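To illustrate how one cost per segment can be obtained from padded inputs, here is a hedged NumPy sketch that runs a batched, kernel-domain Sinkhorn loop over the leading segment axis, standing in for jax.vmap. Assumptions: squared Euclidean cost, inputs already padded, padded entries carry zero weight, and the returned value is the linear term <P, C> of the entropic plan, which need not match the exact convention of ott's regularized OT cost. batched_sinkhorn_cost is a hypothetical name.

```python
import numpy as np


def batched_sinkhorn_cost(x, a, y, b, epsilon=1.0, num_iters=200):
    """Entropic OT between padded segments, one problem per leading index.

    x: [S, n, d], a: [S, n] (zero on padding), y: [S, m, d], b: [S, m].
    Returns <P, C> of the entropic plan for each segment (sketch only).
    """
    # Renormalize each segment's weights so every problem is balanced.
    a = a / a.sum(axis=1, keepdims=True)
    b = b / b.sum(axis=1, keepdims=True)
    # Squared Euclidean cost for all segments at once: [S, n, m].
    cost = ((x[:, :, None, :] - y[:, None, :, :]) ** 2).sum(axis=-1)
    kernel = np.exp(-cost / epsilon)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(num_iters):
        # Zero-weight (padded) rows and columns get zero scaling, hence no mass.
        u = a / np.einsum("snm,sm->sn", kernel, v)
        v = b / np.einsum("snm,sn->sm", kernel, u)
    plan = u[:, :, None] * kernel * v[:, None, :]
    return (plan * cost).sum(axis=(1, 2))


# Two segments, each padded to 3 points in 2-D.
x = np.array([[[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]],
              [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]])
a = np.array([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
y = np.array([[[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]],
              [[2.0, 0.0], [0.0, 0.0], [0.0, 0.0]]])
b = np.array([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
costs = batched_sinkhorn_cost(x, a, y, b, epsilon=0.1)
print(costs.shape)  # (2,)
```

Because padded points have zero weight, their scalings are zero after the first update, so padding only affects the starting point of the iterations and not the fixed point reached for the real points of each segment.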