ott.geometry.grid.Grid.apply_transport_from_potentials


Grid.apply_transport_from_potentials(f, g, vec, axis=0)

Apply the transport matrix computed from potentials to a (batched) vec.

This approach does not instantiate the transport matrix itself. Instead, it uses the potentials to apply the transport through apply_lse_kernel, which keeps the computation numerically stable and lowers its memory footprint.
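Assuming the usual entropic parametrization, with ε the geometry's regularization strength and C its cost matrix, the transport matrix and its two products with a vector v can be written as follows. Each entry of either product is a signed, weighted sum of exponentials over a single index, so it can be evaluated in log space without exponentiating P as a whole:

```latex
\begin{aligned}
a_{ij} &= \frac{f_i + g_j - C_{ij}}{\varepsilon}, \qquad P_{ij} = e^{a_{ij}},\\
\text{axis}=0:\quad (v^\top P)_j &= \sum_i v_i\, e^{a_{ij}} = s \cdot \exp\!\Big(\log\Big|\sum_i v_i\, e^{a_{ij}}\Big|\Big),\\
\text{axis}=1:\quad (P v)_i &= \sum_j v_j\, e^{a_{ij}} = s \cdot \exp\!\Big(\log\Big|\sum_j v_j\, e^{a_{ij}}\Big|\Big),
\end{aligned}
```

Here s denotes the sign of the sum it multiplies. For a batched vec, the same formula is applied to each row independently.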

Computations are carried out in log space and take advantage of the optional (b=…, return_sign=True) parameters of logsumexp.
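The weighted, sign-aware log-sum-exp that those two options provide can be sketched in plain NumPy as below. The helper name `signed_logsumexp` is hypothetical and the code is not the library's implementation; it only shows the two ideas involved: shifting by the maximum exponent to avoid overflow, and carrying the sign of the sum separately because the weights (the entries of vec) may be negative.

```python
import numpy as np


def signed_logsumexp(a, b, axis=-1):
    """Return (log|sum(b * exp(a))|, sign(sum(b * exp(a)))) along `axis`.

    Hypothetical helper: a hand-written sketch of what
    logsumexp(a, b=b, return_sign=True) computes, not library code.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    # Subtract the max of `a` so every exponent is <= 0 and exp() cannot overflow.
    a_max = np.max(a, axis=axis, keepdims=True)
    s = np.sum(b * np.exp(a - a_max), axis=axis, keepdims=True)
    # Weights may be negative: keep the sign apart and take the log of |s| only.
    with np.errstate(divide="ignore"):
        log_abs = np.log(np.abs(s)) + a_max
    return np.squeeze(log_abs, axis=axis), np.squeeze(np.sign(s), axis=axis)


# Exponents near 1000 would overflow if exponentiated directly (np.exp(1000.0) is inf).
a = np.array([1000.0, 1000.5, 999.0])
b = np.array([1.0, -2.0, 0.5])
log_abs, sign = signed_logsumexp(a, b)
print(sign, log_abs)  # the weighted sum is negative, so sign is -1.0
```

The original value is recovered as `sign * np.exp(log_abs)` only when it is representable; otherwise the pair (log_abs, sign) can be kept as is for further log-space work.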

Parameters:
  • f (Array) – jnp.ndarray [num_a,], potential of size num_rows of cost_matrix.

  • g (Array) – jnp.ndarray [num_b,], potential of size num_cols of cost_matrix.

  • vec (Array) – jnp.ndarray [batch, num_a or num_b], vector to be multiplied by the transport matrix defined by the potentials f, g and the geometry geom.

  • axis (int) – selects whether vec multiplies the transport matrix from the left (0) or from the right (1).
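To make the parameter shapes and the axis convention concrete, here is a NumPy sketch under stated assumptions: a dense array `cost` stands in for the geometry's cost matrix, `eps` for its epsilon, and the transport matrix is assumed to be P[i, j] = exp((f[i] + g[j] - cost[i, j]) / eps). The names `apply_transport_sketch` and `_signed_lse` are hypothetical, not OTT functions. For readability the sketch builds a [batch, num_a, num_b] array, so it illustrates the arithmetic only and does not reproduce the memory savings described above.

```python
import numpy as np


def _signed_lse(a, b, axis):
    # Weighted, sign-aware log-sum-exp (same idea as logsumexp(..., b=b, return_sign=True)).
    a_max = np.max(a, axis=axis, keepdims=True)
    s = np.sum(b * np.exp(a - a_max), axis=axis, keepdims=True)
    with np.errstate(divide="ignore"):
        return np.log(np.abs(s)) + a_max, np.sign(s)


def apply_transport_sketch(f, g, cost, vec, eps, axis=0):
    """Hypothetical stand-in for apply_transport_from_potentials on a dense cost.

    axis=0: vec has shape [batch, num_a]; result is vec @ P   -> [batch, num_b]
    axis=1: vec has shape [batch, num_b]; result is vec @ P.T -> [batch, num_a]
    """
    vec = np.atleast_2d(vec)
    log_p = (f[:, None] + g[None, :] - cost) / eps  # log P, shape [num_a, num_b]
    if axis == 0:
        # out[k, j] = sum_i vec[k, i] * P[i, j]: contract over rows (num_a).
        weights, sum_axis = vec[:, :, None], 1  # [batch, num_a, 1]
    elif axis == 1:
        # out[k, i] = sum_j P[i, j] * vec[k, j]: contract over columns (num_b).
        weights, sum_axis = vec[:, None, :], 2  # [batch, 1, num_b]
    else:
        raise ValueError("axis must be 0 or 1")
    # Work in log space, then exponentiate once and reattach the sign.
    log_abs, sign = _signed_lse(log_p[None], weights, sum_axis)
    return np.squeeze(sign * np.exp(log_abs), axis=sum_axis)


rng = np.random.default_rng(0)
num_a, num_b, eps = 4, 3, 0.5
cost = rng.random((num_a, num_b))
f, g = rng.normal(size=num_a), rng.normal(size=num_b)
P = np.exp((f[:, None] + g[None, :] - cost) / eps)  # dense reference, for checking only

v0 = rng.normal(size=(2, num_a))
v1 = rng.normal(size=(2, num_b))
out0 = apply_transport_sketch(f, g, cost, v0, eps, axis=0)
out1 = apply_transport_sketch(f, g, cost, v1, eps, axis=1)
```

Against the explicitly built P, the axis=0 result agrees with `v0 @ P` and the axis=1 result with `v1 @ P.T`. On a Grid, both marginals live on the same grid points, so num_a equals num_b and the result has the same size as vec for either axis.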

Return type:

Array

Returns:

ndarray with the same size as vec.