ott.geometry.grid.Grid.apply_cost
- Grid.apply_cost(arr, axis=0, fn=None, is_linear=False)
Apply cost_matrix to an array (vector or matrix). This function applies the ground geometry’s cost matrix C, of shape [num_a, num_b], to compute either

    output = C arr      (if axis=1)
    output = C^T arr    (if axis=0)

where C^T is the transpose of C.
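The two axis conventions can be sketched with plain jax.numpy on an explicit cost matrix. This is a minimal reference for the semantics only, not OTT's implementation: it materializes the full [num_a, num_b] matrix, which is practical only for small examples. The helpers `grid_points`, `sq_euclidean_cost` and `apply_cost_dense` are hypothetical, and the 2-D grid with squared Euclidean cost is an assumed example.

```python
import jax.numpy as jnp


def grid_points(xs, ys):
    """Flatten a 2-D Cartesian grid into an [n, 2] array of points (n = len(xs) * len(ys))."""
    gx, gy = jnp.meshgrid(xs, ys, indexing="ij")
    return jnp.stack([gx.ravel(), gy.ravel()], axis=-1)


def sq_euclidean_cost(x, y):
    """Dense [num_a, num_b] matrix with C[i, j] = ||x_i - y_j||^2 (assumed example cost)."""
    return jnp.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def apply_cost_dense(cost, arr, axis=0):
    """Hypothetical dense reference: C^T @ arr if axis=0, C @ arr if axis=1."""
    if axis == 0:
        return cost.T @ arr  # C^T: [num_b, num_a] @ [num_a, p] -> [num_b, p]
    if axis == 1:
        return cost @ arr  # C: [num_a, num_b] @ [num_b, p] -> [num_a, p]
    raise ValueError("axis must be 0 or 1")


# Usage: a 3 x 4 grid (num_a = 12) against its first 5 points (num_b = 5).
a = grid_points(jnp.linspace(0.0, 1.0, 3), jnp.linspace(0.0, 1.0, 4))
b = a[:5]
C = sq_euclidean_cost(a, b)  # shape (12, 5)
out0 = apply_cost_dense(C, jnp.ones((12, 2)), axis=0)  # shape (5, 2)
out1 = apply_cost_dense(C, jnp.ones((5, 2)), axis=1)  # shape (12, 2)
```

When both sides are the same grid (num_a == num_b) and the cost is symmetric, as squared Euclidean is, axis=0 and axis=1 return the same values; the axis then only changes the result for asymmetric costs or when the two point sets differ.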
- Parameters:
  - arr (Array) – jnp.ndarray of shape [num_a, p] if axis=0 or [num_b, p] if axis=1; the vector or matrix that will be multiplied by the cost matrix.
  - axis (int) – 1 to apply the cost matrix C as is, 0 to apply its transpose C^T.
  - fn (Optional[Callable[[Array], Array]]) – function applied element-wise to the cost matrix before the dot product.
  - is_linear (bool) – whether fn is linear.
- Returns:
  An array of shape [num_b, p] if axis=0 or [num_a, p] if axis=1.
- Return type:
  Array
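The `fn` and `is_linear` parameters can be illustrated with the same dense approach. This sketch assumes, rather than takes from the OTT source, that `is_linear=True` lets an implementation compute the product first and apply `fn` afterwards. That shortcut is exact only when fn(C) arr == fn(C arr), e.g. for a pure scaling fn(x) = s * x; an affine map such as x + 1 would not qualify. `apply_cost_with_fn`, `scale` and `square` are hypothetical names.

```python
import jax.numpy as jnp


def apply_cost_with_fn(cost, arr, axis=0, fn=None, is_linear=False):
    """Hypothetical dense reference for apply_cost with an element-wise fn."""
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1")
    mat = cost.T if axis == 0 else cost  # C^T for axis=0, C for axis=1
    if fn is None:
        return mat @ arr
    if is_linear:
        # Assumed shortcut: multiply first, then apply fn to the [n, p] result.
        # Exact only when fn commutes with the product, e.g. fn(x) = s * x.
        return fn(mat @ arr)
    # General case: fn is applied to every entry of the cost matrix first.
    return fn(mat) @ arr


C = jnp.array([[0.0, 1.0, 4.0],
               [1.0, 0.0, 1.0]])  # [num_a=2, num_b=3]
arr = jnp.array([[1.0], [2.0]])  # [num_a, p=1], so axis=0 gives [num_b, 1]


def scale(x):
    return 3.0 * x  # linear: fn(C) @ arr == fn(C @ arr)


def square(x):
    return x ** 2  # nonlinear: the shortcut gives a different answer


exact = apply_cost_with_fn(C, arr, axis=0, fn=scale)
shortcut = apply_cost_with_fn(C, arr, axis=0, fn=scale, is_linear=True)
```

Here both `exact` and `shortcut` equal [[6.], [3.], [18.]]. With `square`, the element-wise path gives [[2.], [1.], [18.]] while squaring the product gives [[4.], [1.], [36.]], which is why the shortcut is only safe when `fn` really is linear.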