arr (Array) – jnp.ndarray [num_a or num_b, p], vector that will be multiplied by the cost matrix.
axis (int) – apply the standard cost matrix if axis=1, or its transpose if axis=0.
fn (Optional[Callable[[Array], Array]]) – function optionally applied element-wise to the cost matrix before it is multiplied with arr.
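The parameters above can be illustrated with a minimal dense sketch. It assumes the cost matrix C is fully materialized with shape [num_a, num_b]. `apply_cost_sketch` is a hypothetical helper written for illustration and is not the library's own implementation, which may compute the product without forming C explicitly. With axis=1 it computes C @ arr (arr has num_b rows). With axis=0 it computes C.T @ arr (arr has num_a rows). If fn is given, it is applied element-wise to C first.

```python
from typing import Callable, Optional

import jax.numpy as jnp


def apply_cost_sketch(
    cost_matrix: jnp.ndarray,
    arr: jnp.ndarray,
    axis: int = 0,
    fn: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = None,
) -> jnp.ndarray:
    """Hypothetical dense reference: apply a [num_a, num_b] cost matrix to arr."""
    # Optionally transform the cost matrix element-wise before the product.
    c = cost_matrix if fn is None else fn(cost_matrix)
    if axis == 1:
        # Standard orientation: [num_a, num_b] @ [num_b, p] -> [num_a, p].
        return c @ arr
    if axis == 0:
        # Transposed orientation: [num_b, num_a] @ [num_a, p] -> [num_b, p].
        return c.T @ arr
    raise ValueError("axis must be 0 or 1")


# Toy cost matrix with num_a = 2 and num_b = 3.
cost = jnp.array([[0.0, 1.0, 4.0],
                  [1.0, 0.0, 1.0]])

row_sums = apply_cost_sketch(cost, jnp.ones((3, 1)), axis=1)  # row sums of C: 5 and 2
col_sums = apply_cost_sketch(cost, jnp.ones((2, 1)), axis=0)  # column sums of C: 1, 1 and 5
squared = apply_cost_sketch(cost, jnp.ones((3, 1)), axis=1, fn=jnp.square)  # row sums of C**2: 17 and 2
```

Multiplying by a column of ones makes the effect of axis easy to check: axis=1 sums each row of C, while axis=0 sums each column, i.e. each row of the transpose.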