ott.geometry.geometry.Geometry.apply_cost


Geometry.apply_cost(arr, axis=0, fn=None, **kwargs)

Apply the cost_matrix to an array (vector or matrix).

This function applies the ground geometry’s cost matrix C, of shape [num_a, num_b], to compute either output = C arr (if axis=1) or output = C^T arr (if axis=0).
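The axis convention can be checked against a dense reference. The sketch below uses plain jax.numpy on an explicit matrix rather than a Geometry object; `apply_cost_reference` is a hypothetical helper that mirrors the documented semantics, not the library's implementation, which is not guaranteed to materialize C at all.

```python
import jax.numpy as jnp


def apply_cost_reference(cost, arr, axis=0, fn=None):
    """Hypothetical dense reference for the documented apply_cost semantics.

    cost: [num_a, num_b] cost matrix C.
    arr:  [num_b, p] when axis=1, [num_a, p] when axis=0 (1-D vectors also work).
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1")
    # Optionally transform C element-wise before taking the product.
    c = cost if fn is None else fn(cost)
    # axis=1 -> C @ arr ; axis=0 -> C^T @ arr
    return c @ arr if axis == 1 else c.T @ arr


cost = jnp.array([[0.0, 1.0, 4.0],
                  [1.0, 0.0, 1.0]])  # num_a=2, num_b=3
arr_b = jnp.ones((3, 2))  # [num_b, p], used with axis=1
arr_a = jnp.ones((2, 2))  # [num_a, p], used with axis=0

out1 = apply_cost_reference(cost, arr_b, axis=1)  # [num_a, p] = (2, 2)
out0 = apply_cost_reference(cost, arr_a, axis=0)  # [num_b, p] = (3, 2)
```

With an all-ones `arr`, `axis=1` yields the row sums of C in every column and `axis=0` yields the column sums, which makes the two cases easy to tell apart.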

Parameters:
  • arr (Array) – jnp.ndarray of shape [num_b, p] if axis=1 or [num_a, p] if axis=0; the vector or matrix to multiply by the cost matrix (or its transpose).

  • axis (int) – use the cost matrix C if axis=1, or its transpose C^T if axis=0.

  • fn (Optional[Callable[[Array], Array]]) – function applied element-wise to the cost matrix before the dot product.

  • kwargs (Any) – Keyword arguments for _apply_cost_to_vec().
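To illustrate what `fn` means, here is a small dense sketch in jax.numpy (not a call into the library): `fn` transforms each entry of C before the product, which in general differs from transforming the product's result. `jnp.square` is only an example choice of `fn`, and forming `fn(C)` explicitly, as done here, is an assumption of the sketch rather than a description of how the library evaluates it.

```python
import jax.numpy as jnp

cost = jnp.array([[0.0, 1.0],
                  [2.0, 3.0]])  # [num_a, num_b] = [2, 2]
vec = jnp.array([1.0, 1.0])     # [num_b], i.e. the axis=1 case

fn = jnp.square  # example element-wise transform (illustrative choice)

# Documented behaviour: apply fn to every entry of C, then multiply.
fn_then_dot = fn(cost) @ vec    # (C ** 2) @ vec -> [1, 13]
# Not the same thing: transform the result of the product instead.
dot_then_fn = fn(cost @ vec)    # (C @ vec) ** 2 -> [1, 25]
```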

Return type:

Array

Returns:

An array of shape [num_b, p] if axis=0, or [num_a, p] if axis=1.