ott.geometry.segment.segment_point_cloud(x, a=None, num_segments=None, max_measure_size=None, segment_ids=None, indices_are_sorted=False, num_per_segment=None, padding_vector=None)

Segment the entries of a point cloud and pad them as needed.

There are two interfaces:

  1. Use segment_ids, and optionally indices_are_sorted, to specify for each data point in x the segment to which it belongs.

  2. Use num_per_segment, which gives the sizes of contiguous segments.

When using the first interface, num_segments is required for jitting. The segment ids are assumed to be range(0, num_segments).

In both cases, jitting requires setting max_measure_size, an upper bound on the number of points in any measure, which is used for padding.
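The two interfaces describe the same segmentation: the contiguous sizes in num_per_segment can be expanded into one segment id per point, with ids in range(0, num_segments). The following minimal NumPy sketch uses a hypothetical helper, segment_ids_from_counts (not part of OTT), to show that correspondence together with the defaults that the parameter list gives for num_segments and max_measure_size under the second interface.

```python
import numpy as np


def segment_ids_from_counts(num_per_segment):
    """Hypothetical helper (not an OTT function).

    Expands contiguous segment sizes into one segment id per point and
    computes the defaults used by the 2nd interface.
    """
    counts = np.asarray(num_per_segment)
    # Segment i is repeated num_per_segment[i] times, so ids run 0..num_segments-1.
    segment_ids = np.repeat(np.arange(len(counts)), counts)
    # Defaults when num_segments / max_measure_size are None (2nd interface).
    num_segments = len(num_per_segment)
    max_measure_size = int(counts.max())
    return segment_ids, num_segments, max_measure_size


ids, n_seg, max_size = segment_ids_from_counts((3, 1, 2))
print(ids)       # [0 0 0 1 2 2]
print(n_seg)     # 3
print(max_size)  # 3
```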

Parameters:

  • x (Array) – Array of input points, of shape [num_x, ndim]. Multiple segments are held in this single array.

  • a (Optional[Array]) – Array of shape [num_x,] containing the weights (within each measure) of all the points.

  • num_segments (Optional[int]) – Number of segments. Required for jitting. If None and using the second interface, it will be computed as len(num_per_segment).

  • max_measure_size (Optional[int]) – Size to which every segment is padded, i.e. an upper bound on the number of points in any segment. Required for jitting. If None and using the second interface, it will be computed as max(num_per_segment).

  • segment_ids (Optional[Array]) – (1st interface) The segment id to which each row of x belongs. This interface is similar to that of jax.ops.segment_sum().

  • indices_are_sorted (bool) – (1st interface) Whether segment_ids are sorted.

  • num_per_segment (Optional[Tuple[int, ...]]) – (2nd interface) Number of points in each segment. For example, (100, 20, 30) implies that x is split into 3 contiguous arrays of lengths 100, 20 and 30, respectively. Must be a tuple, not a jax.numpy.ndarray, to allow jitting. As a consequence, any change in num_per_segment triggers recompilation.

  • padding_vector (Optional[Array]) – Vector used to pad the point cloud matrices. Typically zero, but it can be set to other values to avoid errors or over/underflow in the cost matrix (even though the padded points should not contribute, since their corresponding masses are 0). See also _padder(). If None, a vector of zeros of shape [1, ndim] is used.
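Padded rows still appear in the cost matrix, but their zero mass removes them from any mass-weighted sum. The following minimal NumPy sketch illustrates this with two hypothetical helpers, pad_segment and weighted_cost (neither is an OTT function), and a squared-Euclidean cost chosen only for illustration.

```python
import numpy as np


def pad_segment(points, weights, max_measure_size, padding_vector):
    """Hypothetical helper: pad one segment; padded rows get zero mass."""
    num_pad = max_measure_size - points.shape[0]
    padded_points = np.concatenate(
        [points, np.repeat(padding_vector, num_pad, axis=0)], axis=0)
    padded_weights = np.concatenate([weights, np.zeros(num_pad)])
    return padded_points, padded_weights


def weighted_cost(x, a, y, b):
    """Hypothetical helper: sum_ij a_i * b_j * ||x_i - y_j||^2."""
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
    return a @ cost @ b


x = np.array([[0.0, 0.0], [1.0, 0.0]])
a = np.array([0.5, 0.5])
y = np.array([[0.0, 1.0]])
b = np.array([1.0])

# Two different padding vectors give the same weighted cost (1.5 here).
for pad in (np.zeros((1, 2)), np.full((1, 2), 10.0)):
    xp, ap = pad_segment(x, a, max_measure_size=4, padding_vector=pad)
    print(xp.shape, weighted_cost(xp, ap, y, b))
```

Any finite padding_vector gives the same result here. The choice matters when a padded row's cost could overflow to inf or become NaN: in floating point, 0 * inf is NaN, so a zero mass no longer cancels that entry.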

Returns:

Segmented x as an array of shape [num_segments, max_measure_size, ndim] and a as an array of shape [num_segments, max_measure_size].

Return type:

Tuple[Array, Array]
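To make the returned shapes concrete, here is a minimal NumPy sketch of the documented behavior using the first interface. It is not OTT's implementation (which is written in JAX and is jittable): the helper name segment_and_pad, the Python loop, keeping points in their original order within each segment, and raising an error for oversized segments are all assumptions made for illustration.

```python
import numpy as np


def segment_and_pad(x, a, segment_ids, num_segments, max_measure_size,
                    padding_vector=None):
    """Hypothetical sketch of the documented output shapes (not OTT code).

    Returns x as [num_segments, max_measure_size, ndim] and a as
    [num_segments, max_measure_size]; padded entries get zero weight.
    """
    ndim = x.shape[1]
    if padding_vector is None:
        padding_vector = np.zeros((1, ndim))
    # Start from all-padding arrays, then fill each segment's leading rows.
    x_out = np.broadcast_to(
        padding_vector, (num_segments, max_measure_size, ndim)).copy()
    a_out = np.zeros((num_segments, max_measure_size))
    for s in range(num_segments):
        mask = segment_ids == s
        n = int(mask.sum())
        if n > max_measure_size:
            raise ValueError(f"segment {s} has {n} points > max_measure_size")
        x_out[s, :n] = x[mask]
        a_out[s, :n] = a[mask]
    return x_out, a_out


x = np.arange(10.0).reshape(5, 2)            # 5 points in 2D
a = np.array([0.2, 0.3, 0.5, 0.7, 0.3])      # weights sum to 1 per segment
segment_ids = np.array([0, 1, 0, 1, 0])      # segment 0: 3 points, segment 1: 2
x_seg, a_seg = segment_and_pad(x, a, segment_ids, num_segments=2,
                               max_measure_size=3)
print(x_seg.shape, a_seg.shape)  # (2, 3, 2) (2, 3)
print(a_seg[1])                  # last entry is a zero-mass padding slot
```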