ott.geometry.pointcloud.PointCloud.vec_apply_cost

PointCloud.vec_apply_cost(arr, axis=0, fn=None)

Apply the geometry’s cost matrix in a vectorized way.

This method can be used when the cost is the squared Euclidean distance and fn is a linear function.

Parameters:
  • arr (Array) – jnp.ndarray of shape [num_a, p] if axis=0 or [num_b, p] if axis=1; the array that will be multiplied by the cost matrix.

  • axis (int) – apply the standard cost matrix if axis=1, its transpose if axis=0.

  • fn (Optional[Callable[[Array], Array]]) – function optionally applied element-wise to the cost matrix before the matrix is applied to arr.

Return type:

Array

Returns:

A jnp.ndarray of shape [num_b, p] if axis=0 or [num_a, p] if axis=1.
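
The shape convention above, and the reason the squared Euclidean case can be vectorized, can be illustrated with a small NumPy sketch. It is not OTT's implementation: the helper name apply_sqeucl_cost is hypothetical, and the sketch assumes the cost is C[i, j] = ||x_i - y_j||^2 with fn left as None. Because C[i, j] = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, the product with arr can be computed without ever forming the [num_a, num_b] matrix.

```python
import numpy as np


def apply_sqeucl_cost(x, y, arr, axis=0):
    """Apply the squared Euclidean cost between x and y to arr without building it.

    Hypothetical helper for illustration only; not part of OTT.
    x: [num_a, d], y: [num_b, d].
    arr: [num_a, p] if axis=0, [num_b, p] if axis=1.
    """
    nx = np.sum(x ** 2, axis=1)  # squared norms of the rows of x, [num_a]
    ny = np.sum(y ** 2, axis=1)  # squared norms of the rows of y, [num_b]
    col_sums = arr.sum(axis=0)[None, :]  # [1, p]
    if axis == 0:
        # Transposed cost: C.T @ arr, result has shape [num_b, p].
        return (ny[:, None] * col_sums
                + (nx @ arr)[None, :]
                - 2.0 * y @ (x.T @ arr))
    # Standard cost: C @ arr, result has shape [num_a, p].
    return (nx[:, None] * col_sums
            + (ny @ arr)[None, :]
            - 2.0 * x @ (y.T @ arr))


# Compare against the dense cost matrix on random data.
rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))  # num_a = 5
y = rng.normal(size=(4, 3))  # num_b = 4
dense_cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # [5, 4]

arr_a = rng.normal(size=(5, 2))  # [num_a, p]
arr_b = rng.normal(size=(4, 2))  # [num_b, p]

out_axis0 = apply_sqeucl_cost(x, y, arr_a, axis=0)  # [num_b, p]
out_axis1 = apply_sqeucl_cost(x, y, arr_b, axis=1)  # [num_a, p]
```

In this sketch the work scales with (num_a + num_b) * d * p rather than num_a * num_b * p. A linear fn (a scaling of the cost) commutes with the sums involved, so under that assumption it could be applied to the result instead of to each entry of the cost matrix.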