ott.core.segment.segment_point_cloud(x, a=None, num_segments=None, max_measure_size=None, segment_ids=None, indices_are_sorted=False, num_per_segment=None, padding_vector=None)

Segment the entries of a point cloud into separate measures and pad them as needed.

There are two interfaces:

  1. use segment_ids, and optionally indices_are_sorted, to specify, for each data point in the matrix, the segment to which it belongs.

  2. use num_per_segment, which gives the sizes of contiguous segments.

When using the 1st interface, num_segments is required for JIT compilation, and the segment ids are assumed to be range(0, num_segments).

In both cases, jitting requires defining max_measure_size, an upper bound on the size of each measure, which is used for padding.
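The relationship between the two interfaces can be sketched in plain NumPy. This is an illustration only, not OTT's implementation: it assumes that the contiguous segments described by num_per_segment receive the ids 0, 1, 2 and so on in order, and all variable names here are hypothetical.

```python
import numpy as np

# 2nd interface: contiguous segments described by their sizes.
num_per_segment = (3, 1, 2)

# Equivalent 1st-interface description: the segment id of each row of x
# (assumes segments are numbered in the order they appear).
segment_ids = np.repeat(np.arange(len(num_per_segment)), num_per_segment)

# Defaults used by the 2nd interface when these arguments are None.
num_segments = len(num_per_segment)
max_measure_size = max(num_per_segment)

# Contiguous segments always yield sorted ids.
indices_are_sorted = bool(np.all(np.diff(segment_ids) >= 0))

# 1st interface with unsorted ids in range(0, num_segments): the segment
# sizes can be counted, and max_measure_size must bound the largest one.
ids_unsorted = np.array([2, 0, 1, 0, 2, 0])
sizes = np.bincount(ids_unsorted, minlength=num_segments)
assert max_measure_size >= sizes.max()
```

Because num_per_segment is a Python tuple, the derived num_segments and max_measure_size are plain integers, which is what makes them usable as fixed shapes under JIT.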

Parameters

  • x (ndarray) – Array of input points, of shape [num_x, ndim]. Multiple segments are held in this single array.

  • a (Optional[ndarray]) – Array of shape [num_x,] containing the weights (within each measure) of all the points.

  • num_segments (Optional[int]) – Number of segments. Required for jitting. If None and using the 2nd interface, it will be computed as len(num_per_segment).

  • max_measure_size (Optional[int]) – Size to which each segment is padded, i.e. an upper bound on the number of points in any segment. Required for jitting. If None and using the 2nd interface, it will be computed as max(num_per_segment).

  • segment_ids (Optional[ndarray]) – (1st interface) The segment id to which each row of x belongs. This interface is similar to that of jax.ops.segment_sum().

  • indices_are_sorted (bool) – (1st interface) Whether segment_ids are sorted.

  • num_per_segment (Optional[Tuple[int, ...]]) – (2nd interface) Number of points in each segment. For example, (100, 20, 30) implies that x is split into 3 contiguous arrays of lengths 100, 20, and 30, respectively. It must be a tuple, not a jax.numpy.ndarray, to allow jitting; as a consequence, changing num_per_segment re-triggers compilation.

  • padding_vector (Optional[ndarray]) – Vector used to pad point cloud matrices. It is usually zero, but can be set to other values to avoid errors or over/underflow in the cost matrix, even though the padded entries should have no effect since their corresponding masses are 0. See also ott.geometry.costs.CostFn.padder(). If None, a vector of 0s of shape [1, ndim] is used.
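Why the padding value can still matter despite zero masses can be illustrated with a toy weighted cost in NumPy. The arrays, the squared Euclidean cost, and the helper weighted_cost are hypothetical stand-ins, not OTT code; the point is only floating-point arithmetic: a zero mass cancels any finite cost, but 0 * inf evaluates to nan.

```python
import numpy as np

def weighted_cost(x, a, y, b):
    # Hypothetical helper: squared Euclidean cost matrix, each entry weighted
    # by the product of the two masses (elementwise, so 0 * inf is not skipped).
    c = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return float((a[:, None] * c * b[None, :]).sum())

# A 2-D point cloud padded to 3 rows: the last row is padding with mass 0.
x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 0.0]])
a = np.array([0.5, 0.5, 0.0])
y = np.array([[0.0, 1.0], [2.0, 1.0], [1.0, 1.0]])
b = np.array([0.2, 0.3, 0.5])

# Any padding value giving a finite cost leaves the result unchanged.
x_other_pad = x.copy()
x_other_pad[2] = [5.0, -3.0]
same = np.isclose(weighted_cost(x, a, y, b), weighted_cost(x_other_pad, a, y, b))

# A padding value that makes the cost overflow to inf poisons the result:
# mass 0 times cost inf is nan.
x_bad_pad = x.copy()
x_bad_pad[2] = [1e200, 0.0]
with np.errstate(over="ignore", invalid="ignore"):
    bad = weighted_cost(x_bad_pad, a, y, b)  # -> nan
```

This is the kind of failure a cost-specific padding vector is meant to prevent.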

Return type

Tuple[ndarray, ndarray]

Returns

Segmented x as an array of shape [num_segments, max_measure_size, ndim] and a as an array of shape [num_segments, max_measure_size], where padded entries have weight 0.
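The documented return shapes can be reproduced with a plain NumPy sketch. It assumes the 1st interface (segment ids), keeps rows in their original order within each segment, and does not normalize weights. The helper segment_and_pad is hypothetical, not part of OTT, and is not jittable as written since it loops in Python; it only shows what the padded outputs contain.

```python
import numpy as np

def segment_and_pad(x, a, segment_ids, num_segments, max_measure_size,
                    padding_vector=None):
    """Hypothetical sketch: pad each segment of x (and its weights a)."""
    ndim = x.shape[1]
    if padding_vector is None:
        padding_vector = np.zeros((1, ndim), dtype=x.dtype)
    # Every slot starts as the padding vector with zero mass.
    x_seg = np.broadcast_to(
        padding_vector, (num_segments, max_measure_size, ndim)).copy()
    a_seg = np.zeros((num_segments, max_measure_size), dtype=a.dtype)
    for s in range(num_segments):
        rows = np.flatnonzero(segment_ids == s)  # rows of segment s, in order
        if len(rows) > max_measure_size:
            raise ValueError("max_measure_size is smaller than a segment")
        x_seg[s, :len(rows)] = x[rows]
        a_seg[s, :len(rows)] = a[rows]
    return x_seg, a_seg

# Usage: 6 points in 2-D split into segments of sizes 3, 1 and 2.
x = np.arange(12, dtype=float).reshape(6, 2)
a = np.array([1 / 3, 1 / 3, 1 / 3, 1.0, 0.5, 0.5])
num_per_segment = (3, 1, 2)
segment_ids = np.repeat(np.arange(len(num_per_segment)), num_per_segment)
x_seg, a_seg = segment_and_pad(x, a, segment_ids, len(num_per_segment),
                               max(num_per_segment))
print(x_seg.shape, a_seg.shape)  # (3, 3, 2) (3, 3)
```

Padded slots hold the padding vector and weight 0, so each row of a_seg still sums to the total weight of its original segment.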