ott.geometry.pointcloud.PointCloud.apply_cost

PointCloud.apply_cost(arr, axis=0, fn=None, is_linear=False)

Apply the cost matrix to an array (vector or matrix).

This function applies the geometry's cost matrix to compute either output = C arr (if axis=1) or output = C^T arr (if axis=0), where C is the [num_a, num_b] matrix obtained by (optionally) applying fn element-wise to each entry of cost_matrix.
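As a hedged reference for these semantics, the NumPy sketch below materializes a squared Euclidean cost matrix (the cost PointCloud uses by default) and applies it densely, with an optional element-wise fn applied before the product. The helper names sq_euclidean_cost and apply_cost_dense are hypothetical and not part of OTT; OTT's own implementation is written in JAX and may avoid forming C explicitly.

```python
import numpy as np


def sq_euclidean_cost(x, y):
    """Dense [num_a, num_b] matrix with entries ||x_i - y_j||^2."""
    # Broadcast to a [num_a, num_b, dim] difference tensor, then sum over dim.
    diff = x[:, None, :] - y[None, :, :]
    return np.sum(diff ** 2, axis=-1)


def apply_cost_dense(x, y, arr, axis=0, fn=None):
    """Hypothetical reference: fn(C) @ arr if axis=1, fn(C).T @ arr if axis=0."""
    cost = sq_euclidean_cost(x, y)  # C, shape [num_a, num_b]
    if fn is not None:
        cost = fn(cost)  # element-wise transform before the product
    if axis == 1:
        return cost @ arr  # arr: [num_b, batch] -> [num_a, batch]
    return cost.T @ arr  # arr: [num_a, batch] -> [num_b, batch]


x = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])  # num_a = 3 points in 2-D
y = np.array([[1.0, 1.0], [2.0, 0.0]])  # num_b = 2 points in 2-D
print(apply_cost_dense(x, y, np.ones((2, 1)), axis=1).shape)  # (3, 1)
print(apply_cost_dense(x, y, np.ones((3, 1)), axis=0).shape)  # (2, 1)
```

The shapes follow directly from the definition: axis=1 contracts over num_b and returns num_a rows, while axis=0 contracts over num_a and returns num_b rows.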

Parameters:
  • arr (Array) – jnp.ndarray of shape [num_b, batch] if axis=1 or [num_a, batch] if axis=0; the array that will be multiplied by the cost matrix.

  • axis (int) – apply the cost matrix C if axis=1, or its transpose C^T if axis=0.

  • fn (Optional[Callable[[Array], Array]]) – function optionally applied element-wise to the cost matrix before the matrix product.

  • is_linear (bool) – whether fn is a linear function. If True and is_squared_euclidean is True, a more efficient implementation is used. See ott.geometry.geometry.is_linear() for a heuristic that helps determine whether a function is linear.
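One way a squared Euclidean shortcut can work: the cost matrix factors as C = a 1^T + 1 b^T - 2 x y^T, with a_i = ||x_i||^2 and b_j = ||y_j||^2, so C arr can be assembled from thin products without ever forming the [num_a, num_b] matrix. The NumPy sketch below illustrates that idea under stated assumptions; apply_sqeucl_cost_lowrank is a hypothetical name, this is not OTT's actual code path, and it assumes fn is strictly linear (fn(c) = s * c for a scalar s), in which case applying fn after the product gives the same result as applying it to C first.

```python
import numpy as np


def apply_sqeucl_cost_lowrank(x, y, arr, axis=0, fn=None):
    """Hypothetical sketch: apply fn(C) without materializing C.

    For the squared Euclidean cost, C = a 1^T + 1 b^T - 2 x y^T with
    a_i = ||x_i||^2 and b_j = ||y_j||^2, so each product costs
    O((num_a + num_b) * dim * batch) instead of O(num_a * num_b * batch).
    Assumes fn is strictly linear (fn(c) = s * c), so fn(C) @ arr == fn(C @ arr).
    """
    a = np.sum(x ** 2, axis=1, keepdims=True)  # [num_a, 1]
    b = np.sum(y ** 2, axis=1, keepdims=True)  # [num_b, 1]
    col_sums = arr.sum(axis=0, keepdims=True)  # 1^T arr, shape [1, batch]
    if axis == 1:
        # C @ arr = a (1^T arr) + 1 (b^T arr) - 2 x (y^T arr); arr: [num_b, batch]
        out = a * col_sums + b.T @ arr - 2.0 * x @ (y.T @ arr)
    else:
        # C^T @ arr = b (1^T arr) + 1 (a^T arr) - 2 y (x^T arr); arr: [num_a, batch]
        out = b * col_sums + a.T @ arr - 2.0 * y @ (x.T @ arr)
    # Valid only for linear fn: applying it to the product equals applying it to C first.
    return fn(out) if fn is not None else out


rng = np.random.default_rng(0)
x, y = rng.normal(size=(4, 3)), rng.normal(size=(6, 3))
dense = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # explicit C for comparison
arr = rng.normal(size=(6, 2))
print(np.allclose(apply_sqeucl_cost_lowrank(x, y, arr, axis=1), dense @ arr))  # True
```

The trade-off is one dense [num_a, num_b] product versus a handful of thin products, which pays off when num_a and num_b are large compared with the dimension. For a non-linear fn such as a square root, the identity breaks and the entries of C must be transformed individually.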

Return type:

Array

Returns:

A jnp.ndarray of shape [num_b, batch] if axis=0 or [num_a, batch] if axis=1.