ott.geometry.graph.Graph.apply_transport_from_scalings


Graph.apply_transport_from_scalings(u, v, vec, axis=0)

Apply the transport matrix computed from the scalings u and v to a (batched) vec.

This approach does not instantiate the transport matrix itself; it relies instead on the apply_kernel method.
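Assuming the usual Sinkhorn factorization of the transport matrix P in terms of the geometry's kernel K and the scalings, both products reduce to elementwise scalings around a single kernel application. That is why a kernel-apply routine is enough and P never has to be formed:

```latex
P = \operatorname{diag}(u)\,K\,\operatorname{diag}(v), \qquad
P^\top w = v \odot \bigl(K^\top (u \odot w)\bigr), \qquad
P\,w = u \odot \bigl(K\,(v \odot w)\bigr)
```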

Parameters:
  • u (Array) – jnp.ndarray [num_a,], scaling vector of size num_rows of cost_matrix.

  • v (Array) – jnp.ndarray [num_b,], scaling vector of size num_cols of cost_matrix.

  • vec (Array) – jnp.ndarray [batch, num_a or num_b], vector(s) to be multiplied by the transport matrix corresponding to the scalings u, v and the geometry.

  • axis (int) – selects left (0) or right (1) multiplication by the transport matrix: with 0, vec has size num_a; with 1, vec has size num_b.

Return type:

Array

Returns:

An ndarray of the same size as vec.
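The following is a minimal JAX sketch, not OTT's implementation, of the matrix-free product and the axis convention. It assumes P = diag(u) K diag(v), and that, for a batch of row vectors, axis=0 means vec @ P (left multiplication) while axis=1 means vec @ P.T. The helpers apply_transport and dense_apply_kernel are hypothetical, and a small random dense K stands in for the graph kernel only so the result can be checked against an explicitly built P.

```python
import jax
import jax.numpy as jnp


def apply_transport(u, v, vec, apply_kernel, axis=0):
  """Hypothetical matrix-free product with P = diag(u) @ K @ diag(v).

  `apply_kernel(x, axis)` is assumed to return x @ K for axis == 0 and
  x @ K.T for axis == 1, where x has shape [batch, n].
  """
  if axis == 0:
    # Left multiplication: vec @ P = ((vec * u) @ K) * v
    return apply_kernel(vec * u, axis=0) * v
  # Right multiplication, row by row: vec @ P.T = ((vec * v) @ K.T) * u
  return apply_kernel(vec * v, axis=1) * u


# Toy setup (assumption): a small positive dense kernel replaces the
# graph kernel, which the real geometry never materializes.
key_k, key_u, key_v, key_x = jax.random.split(jax.random.PRNGKey(0), 4)
n = 5
K = jnp.exp(-jax.random.uniform(key_k, (n, n)))
u = jax.random.uniform(key_u, (n,)) + 0.5
v = jax.random.uniform(key_v, (n,)) + 0.5
vec = jax.random.normal(key_x, (3, n))  # batch of 3 vectors


def dense_apply_kernel(x, axis):
  # Hypothetical stand-in for a kernel-apply routine.
  return x @ K if axis == 0 else x @ K.T


P = u[:, None] * K * v[None, :]  # built only to verify the sketch
left = apply_transport(u, v, vec, dense_apply_kernel, axis=0)
right = apply_transport(u, v, vec, dense_apply_kernel, axis=1)
print(jnp.allclose(left, vec @ P, atol=1e-5))
print(jnp.allclose(right, vec @ P.T, atol=1e-5))
```

Passing the kernel application as a function is the point of the design: swapping dense_apply_kernel for a routine that never builds K keeps the cost of each product at one kernel application plus two elementwise scalings.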