ott.solvers.utils.sample_conditional


ott.solvers.utils.sample_conditional(rng, tmat, *, k=1)

Sample (source, target) index pairs from a transport matrix, drawing target indices conditionally on source indices.
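For intuition, here is the standard reading of conditional sampling from a coupling. It is stated as a general definition, not taken from the implementation, and assumes the transport matrix P (the tmat argument, shape [n, m]) has nonnegative entries. Read as an unnormalized joint distribution over index pairs, P gives the conditional distribution of a target index j, for a fixed source index i, as the normalized i-th row:

```latex
p(j \mid i) = \frac{P_{ij}}{\sum_{j'=1}^{m} P_{ij'}}, \qquad i = 1, \dots, n, \quad j = 1, \dots, m
```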

Parameters:
  • rng (Array) – JAX random key (PRNG key).

  • tmat (Array) – Transport matrix of shape [n, m].

  • k (int) – Expected number of target samples drawn per source sample.

Return type:

Tuple[Array, Array]

Returns:

Source and target indices, both of shape [n, k]. The target index at position (i, j) is drawn conditionally on the source index at the same position.
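The procedure can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the library's implementation. `sample_conditional_sketch` is a hypothetical name. tmat is assumed nonnegative with positive total mass. Source rows are assumed to be resampled in proportion to the row sums, and k targets are then drawn from each sampled row. The sketch uses only `jax.random.split`, `jax.random.choice` and `jax.vmap`.

```python
import jax
import jax.numpy as jnp


def sample_conditional_sketch(rng, tmat, k=1):
    """Hypothetical sketch of conditional sampling from a transport matrix.

    Assumes ``tmat`` is nonnegative with a positive total mass.
    Returns source and target indices, both of shape (n, k).
    """
    n, m = tmat.shape
    rng_src, rng_tgt = jax.random.split(rng)

    # Assumption: resample n source rows in proportion to the row sums.
    src_marginal = tmat.sum(axis=1)
    src = jax.random.choice(
        rng_src, n, shape=(n,), p=src_marginal / src_marginal.sum()
    )

    # Draw k targets per sampled row from the normalized row p(j | i).
    # Zero-mass rows have zero probability of being drawn above.
    row_keys = jax.random.split(rng_tgt, n)
    tgt = jax.vmap(
        lambda key, row: jax.random.choice(key, m, shape=(k,), p=row / row.sum())
    )(row_keys, tmat[src])

    # Repeat each source index k times so both arrays have shape (n, k).
    src_ixs = jnp.repeat(src[:, None], k, axis=1)
    return src_ixs, tgt


key = jax.random.PRNGKey(0)
tmat = jnp.array([[0.5, 0.0], [0.0, 0.0], [0.0, 0.5]])
src_ixs, tgt_ixs = sample_conditional_sketch(key, tmat, k=5)
print(src_ixs.shape, tgt_ixs.shape)  # (3, 5) (3, 5)
```

Because both arrays share the shape [n, k], src_ixs[i, j] and tgt_ixs[i, j] can be read together as one sampled pair. In this example the middle row has no mass, so source index 1 never appears.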