ott.tools.sinkhorn_divergence.sinkdiv

ott.tools.sinkhorn_divergence.sinkdiv(x, y, *, cost_fn=None, epsilon=None, **kwargs)

Wrapper to get the Sinkhorn divergence between two point clouds.

Convenience wrapper around sinkhorn_divergence() that computes the Sinkhorn divergence between two point clouds, compared using any ground cost CostFn. See sinkhorn_divergence() for other relevant arguments.
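For orientation, the Sinkhorn divergence is commonly defined, in the balanced case, as the entropic OT cost between the two measures, debiased by the two self-transport terms. The notation below (OT_ε for the entropy-regularized transport cost, μ and ν for the weighted point clouds built from x and y) is introduced here for illustration and does not appear in the signature:

```latex
\mathrm{SD}_{\varepsilon}(\mu, \nu)
  = \mathrm{OT}_{\varepsilon}(\mu, \nu)
  - \tfrac{1}{2}\,\mathrm{OT}_{\varepsilon}(\mu, \mu)
  - \tfrac{1}{2}\,\mathrm{OT}_{\varepsilon}(\nu, \nu)
```

Under this reading, the last term is the transport problem between the points in y and themselves, which is the computation that the static_b and offset_static_b keyword arguments described under Parameters allow the caller to bypass.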

Parameters:
  • x (Array) – Array of input points, of shape [num_x, feature].

  • y (Array) – Array of target points, of shape [num_y, feature].

  • cost_fn (Optional[CostFn]) – Ground cost function used to compare points in x and y.

  • epsilon (Optional[float]) – Entropic regularization strength.

  • kwargs (Any) – Keyword arguments passed on to the generic sinkhorn_divergence() function. Of particular interest are the weight vectors a and b; static_b and offset_static_b, which can be used to bypass the computation of the transport problem between the points stored in y (possibly with weights b) and themselves; and solve_kwargs, which parameterizes the linear OT solver.

Return type:

Tuple[Array, SinkhornDivergenceOutput]

Returns:

The Sinkhorn divergence value, and output object detailing computations.