ott.core.bar_problems.BarycenterProblem

class ott.core.bar_problems.BarycenterProblem(y=None, b=None, weights=None, cost_fn=None, epsilon=None, debiased=False, segment_ids=None, num_segments=None, indices_are_sorted=None, num_per_segment=None, max_measure_size=None)

Definition of a regularized OT barycenter problem, along with some helper tools.
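
For orientation, the objective that a barycenter problem of this kind parameterizes can be sketched as below. This is the standard weighted regularized barycenter formulation, not text taken from the library; the symbols \nu_i (the input measures), \lambda_i, \mu, N and W_\varepsilon are notation introduced here, with \lambda_i standing for weights, c for cost_fn, \varepsilon for epsilon and N for num_segments.

```latex
\mu^\star \in \arg\min_{\mu} \; \sum_{i=1}^{N} \lambda_i \, W_\varepsilon(\mu, \nu_i),
\qquad
W_\varepsilon(\mu, \nu) = \min_{P \in U(\mu, \nu)} \langle P, C \rangle - \varepsilon H(P),
\qquad
C_{jk} = c(x_j, y_k)
```

Here U(\mu, \nu) is the set of couplings with marginals \mu and \nu, and H(P) is the entropy of the coupling P. When debiased is set, the description of that parameter suggests that a self-transport term of the form W_\varepsilon(\mu, \mu) also enters the gradient; the exact form of that correction is an assumption here, not something stated on this page.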

Parameters
  • y (Optional[ndarray]) – a matrix merging the points of all measures.

  • b (Optional[ndarray]) – a vector containing the weights (within each measure) of all the points.

  • weights (Optional[ndarray]) – weights given to each measure in the barycenter problem (array of size num_segments).

  • cost_fn (Optional[CostFn]) – cost function used.

  • epsilon (Optional[ndarray]) – entropic regularization strength used to solve the regularized OT problems.

  • debiased (bool) – whether the problem is debiased, in the sense that the regularized transportation cost of the barycenter to itself is taken into account when computing the gradient. Note that if the debiased option is used, the barycenter size (used in the call function) needs to be smaller than the max_measure_size parameter below, for parallelization to operate efficiently.

  • segment_ids (Optional[ndarray]) – describes, for each point, the measure to which it belongs.

  • num_segments (Optional[ndarray]) – total number of measures.

  • indices_are_sorted (Optional[bool]) – flag indicating whether the indices in segment_ids are sorted.

  • num_per_segment (Optional[ndarray]) – number of points in each segment, if contiguous.

  • max_measure_size (Optional[int]) – maximum number of points in each segment (for efficient jit compilation).
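
The flat layout described by the parameters above can be illustrated with a minimal NumPy sketch. It only builds arrays that match the parameter descriptions (points of all measures merged in y, per-point weights in b, one measure index per point in segment_ids); the three point clouds, their sizes and the uniform weights are hypothetical values chosen for illustration, and nothing here calls the library itself.

```python
import numpy as np

# Three hypothetical 2-D point clouds of different sizes (illustrative data only).
rng = np.random.default_rng(0)
sizes = [3, 5, 2]
clouds = [rng.normal(size=(n, 2)) for n in sizes]

# y: the points of all measures merged into one matrix, one row per point.
y = np.concatenate(clouds, axis=0)

# b: per-point weights; here each measure is uniform, so its weights sum to 1.
b = np.concatenate([np.full(n, 1.0 / n) for n in sizes])

# segment_ids: for each row of y, the index of the measure it belongs to.
segment_ids = np.repeat(np.arange(len(sizes)), sizes)

num_segments = len(sizes)            # total number of measures
num_per_segment = np.array(sizes)    # alternative description for contiguous points
max_measure_size = max(sizes)        # largest number of points in any measure
weights = np.full(num_segments, 1.0 / num_segments)  # uniform barycenter weights

# The ids are non-decreasing here, so they count as sorted.
indices_are_sorted = bool(np.all(np.diff(segment_ids) >= 0))
```

Arrays shaped like these could then be passed through the keyword arguments listed in the class signature (y, b, weights, segment_ids, num_segments, and so on), presumably after conversion to JAX arrays. Since segment_ids and num_per_segment describe the same partition, one of the two would normally be enough.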

Methods

add_slice_for_debiased(y, b)

Return type

Tuple[Optional[ndarray], Optional[ndarray]]

Attributes

flattened_b

Return type

Optional[ndarray]

flattened_y

Return type

Optional[ndarray]

max_measure_size

Return type

int

num_segments

Return type

int

segmented_y_b

Return type

Tuple[Optional[ndarray], Optional[ndarray]]
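
As a rough illustration of what a segmented pair of arrays could look like, the sketch below pads each measure to max_measure_size and stacks the results. The output shapes, (num_segments, max_measure_size, dim) for the points and (num_segments, max_measure_size) for the weights, and the zero-weight padding are assumptions made for this example; the helper pad_segments is hypothetical and is not part of the library.

```python
import numpy as np

def pad_segments(y, b, segment_ids, num_segments, max_measure_size):
    """Hypothetical helper: split flat (y, b) into zero-padded per-measure arrays."""
    dim = y.shape[1]
    seg_y = np.zeros((num_segments, max_measure_size, dim))
    seg_b = np.zeros((num_segments, max_measure_size))
    for s in range(num_segments):
        mask = segment_ids == s
        n = int(mask.sum())
        if n > max_measure_size:
            raise ValueError(f"segment {s} has {n} points, more than {max_measure_size}")
        # Real points go first; the remaining rows keep zero coordinates and zero weight.
        seg_y[s, :n] = y[mask]
        seg_b[s, :n] = b[mask]
    return seg_y, seg_b

# Two illustrative measures with 2 and 3 points in 2-D.
y = np.arange(10, dtype=float).reshape(5, 2)
b = np.array([0.5, 0.5, 1 / 3, 1 / 3, 1 / 3])
segment_ids = np.array([0, 0, 1, 1, 1])

seg_y, seg_b = pad_segments(y, b, segment_ids, num_segments=2, max_measure_size=3)
```

Padding every measure to a common size gives arrays with fixed shapes, which is the kind of layout that makes jit compilation and batched solves practical. Padded points carry zero weight in this sketch, so they would not contribute mass.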

weights

Return type

ndarray