Geometry.apply_transport_from_potentials(f, g, vec, axis=0)

Apply the transport matrix computed from the potentials f and g to a (batched) vector vec.

This approach does not instantiate the transport matrix itself. Instead, it uses the potentials to apply the transport through apply_lse_kernel, which keeps the computation numerically stable and lowers its memory footprint.
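For reference, here is a sketch of the quantity involved. It assumes the usual entropic-OT convention in which the potentials f, g, a cost matrix C and a regularization ε define the transport matrix P. The symbols C, ε, n_a, n_b and s_j are our notation. Reading axis=0 as a product with a vector v of size num_a is an assumption based on the parameter list.

```latex
\begin{aligned}
P_{ij} &= \exp\!\left(\frac{f_i + g_j - C_{ij}}{\varepsilon}\right),
  \qquad 1 \le i \le n_a,\; 1 \le j \le n_b,\\
(v^\top P)_j &= \sum_{i=1}^{n_a} v_i\, e^{(f_i + g_j - C_{ij})/\varepsilon}
  = s_j \exp\!\left(\log\left|\sum_{i=1}^{n_a} v_i\, e^{(f_i + g_j - C_{ij})/\varepsilon}\right|\right),
\end{aligned}
```

Here s_j is the sign of the inner sum. The log of the absolute value is a weighted log-sum-exp over i, so P never has to be formed. The right product (Pv)_i is analogous, with the sum taken over j instead of i.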

Computations are done in log space and use the optional b=… and return_sign=True arguments of logsumexp, so that vec may also contain negative entries.
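The signed log-sum-exp trick can be illustrated with a minimal sketch. It uses jax.scipy.special.logsumexp and assumes that P_ij = exp((f_i + g_j - C_ij) / eps), with a dense cost matrix and the product taken on the left (vec @ P). The helper transport_t_vec_lse and the random data are hypothetical and only serve as illustration. This sketch still builds the num_a × num_b log-kernel for clarity. Real memory savings would require computing that kernel on the fly, for example from point clouds.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def transport_t_vec_lse(f, g, cost, eps, vec):
    """Hypothetical helper: out_j = sum_i vec_i * exp((f_i + g_j - C_ij) / eps).

    The weighted sum goes through a signed log-sum-exp (which shifts by the
    maximum before exponentiating), so vec may contain negative values.
    """
    log_kernel = (f[:, None] + g[None, :] - cost) / eps  # log P_ij, [num_a, num_b]
    # b=vec[:, None] weights row i by vec_i; return_sign=True recovers the sign
    # of each weighted sum, since log|.| alone cannot represent negatives.
    lse, sign = logsumexp(log_kernel, axis=0, b=vec[:, None], return_sign=True)
    return sign * jnp.exp(lse)  # [num_b]


# Illustrative random potentials, cost and vector (not real OT solutions).
k_f, k_g, k_c, k_v = jax.random.split(jax.random.PRNGKey(0), 4)
num_a, num_b, eps = 5, 3, 0.5
f = 0.1 * jax.random.normal(k_f, (num_a,))
g = 0.1 * jax.random.normal(k_g, (num_b,))
cost = jax.random.uniform(k_c, (num_a, num_b))
vec = jax.random.normal(k_v, (num_a,))  # has negative entries

out = transport_t_vec_lse(f, g, cost, eps, vec)

# Dense reference that does instantiate P, for comparison only.
P = jnp.exp((f[:, None] + g[None, :] - cost) / eps)
dense = jnp.sum(vec[:, None] * P, axis=0)  # same as vec @ P
```

Without return_sign=True, a negative weighted sum would come back as NaN, because its logarithm is undefined.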

Parameters

  • f (Array) – jnp.ndarray [num_a,], potential of size num_rows of cost_matrix

  • g (Array) – jnp.ndarray [num_b,], potential of size num_cols of cost_matrix

  • vec (Array) – jnp.ndarray [batch, num_a or num_b], vector(s) to be multiplied by the transport matrix defined by the potentials f, g and this geometry.

  • axis (int) – axis selecting left (0) or right (1) multiplication by the transport matrix.
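Batching and the role of axis can be seen in a small sketch. It rests on an assumed convention that this page does not state. With axis=0, vec has shape [batch, num_a] and each row is multiplied from the left (vec @ P). With axis=1, vec has shape [batch, num_b] and P is applied from the right. The function apply_transport_batched is hypothetical and is not the library's implementation. It uses jax.vmap over the batch dimension and the same signed logsumexp as above.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def apply_transport_batched(f, g, cost, eps, vec, axis=0):
    """Hypothetical sketch of batched, log-space transport application.

    Assumed convention: axis=0 maps [batch, num_a] -> [batch, num_b] (vec @ P);
    axis=1 maps [batch, num_b] -> [batch, num_a] (P @ v for each row v).
    P_ij = exp((f_i + g_j - C_ij) / eps) is never formed in linear space.
    """
    vec = jnp.atleast_2d(vec)  # a single vector is treated as a batch of one
    log_kernel = (f[:, None] + g[None, :] - cost) / eps  # [num_a, num_b]

    def apply_one(v):
        if axis == 0:
            # Weight row i by v_i, reduce over rows -> [num_b]
            lse, sign = logsumexp(log_kernel, axis=0, b=v[:, None], return_sign=True)
        else:
            # Weight column j by v_j, reduce over columns -> [num_a]
            lse, sign = logsumexp(log_kernel, axis=1, b=v[None, :], return_sign=True)
        return sign * jnp.exp(lse)

    # vmap maps apply_one over the leading (batch) dimension of vec.
    return jax.vmap(apply_one)(vec)


k_f, k_g, k_c, k_a, k_b = jax.random.split(jax.random.PRNGKey(1), 5)
num_a, num_b, batch, eps = 4, 6, 8, 0.5
f = 0.1 * jax.random.normal(k_f, (num_a,))
g = 0.1 * jax.random.normal(k_g, (num_b,))
cost = jax.random.uniform(k_c, (num_a, num_b))
vec_a = jax.random.normal(k_a, (batch, num_a))
vec_b = jax.random.normal(k_b, (batch, num_b))

left = apply_transport_batched(f, g, cost, eps, vec_a, axis=0)   # [batch, num_b]
right = apply_transport_batched(f, g, cost, eps, vec_b, axis=1)  # [batch, num_a]

# Dense transport matrix, for checking only.
P = jnp.exp((f[:, None] + g[None, :] - cost) / eps)
```

Because axis is a plain Python int, the branch is resolved at trace time and vmap only sees one code path.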

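Return type

  ndarray with the same batch dimension as vec; its last dimension is the other side of the transport matrix (num_b if vec has size num_a, and vice versa).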