ott.geometry.graph.Graph.apply_transport_from_scalings
- Graph.apply_transport_from_scalings(u, v, vec, axis=0)
Apply transport matrix computed from scalings to a (batched) vec.
This approach does not instantiate the transport matrix itself, but relies instead on the apply_kernel function.
- Parameters:
  - u (Array) – jnp.ndarray [num_a,], scaling of size num_rows of cost_matrix
  - v (Array) – jnp.ndarray [num_b,], scaling of size num_cols of cost_matrix
  - vec (Array) – jnp.ndarray [batch, num_a or num_b], vector that will be multiplied by transport matrix corresponding to scalings u, v, and geom.
  - axis (int) – axis to differentiate left (0) or right (1) multiply.
- Return type:
- Returns:
ndarray of the size of vec.
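
To illustrate what "not instantiating the transport matrix" means, here is a minimal NumPy sketch. It is not the library's implementation. It assumes the transport matrix has the form P = diag(u) K diag(v) for some kernel matrix K, and that "left" means vec @ P while "right" means P @ vec. The function name `apply_transport_from_scalings_sketch` and the random `kernel` are hypothetical stand-ins; in the library the kernel product is delegated to apply_kernel on the geometry. The sketch handles a single (unbatched) vec only.

```python
import numpy as np


def apply_transport_from_scalings_sketch(kernel, u, v, vec, axis=0):
    """Hypothetical helper: multiply vec by P = diag(u) @ kernel @ diag(v).

    P is never formed; only kernel-vector products are used.
    """
    if axis == 1:
        # Right multiply: P @ vec = u * (K @ (v * vec)); vec has num_b entries.
        return u * (kernel @ (v * vec))
    # Left multiply: vec @ P = v * (K.T @ (u * vec)); vec has num_a entries.
    return v * (kernel.T @ (u * vec))


rng = np.random.default_rng(0)
num_a, num_b = 4, 3

# Assumed stand-in for the geometry's kernel matrix (positive entries).
kernel = np.exp(-rng.random((num_a, num_b)))
u = rng.random(num_a)
v = rng.random(num_b)

# Dense transport matrix, built here only to check the sketch.
transport = u[:, None] * kernel * v[None, :]

vec_a = rng.random(num_a)
vec_b = rng.random(num_b)

left = apply_transport_from_scalings_sketch(kernel, u, v, vec_a, axis=0)
right = apply_transport_from_scalings_sketch(kernel, u, v, vec_b, axis=1)

# Both products agree with the dense computation.
print(np.allclose(left, vec_a @ transport))
print(np.allclose(right, transport @ vec_b))
```

The dense `transport` array appears only as a reference for the check; the point of the scaling-based form is that each product costs one kernel application plus two elementwise multiplications.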