ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput


class ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput(divergence, geoms, a, b, potentials, factors, errors, converged, n_iters)

Holds the outputs of a call to sinkhorn_divergence().

Objects of this class contain both the solutions and the problem definitions of the two or three regularized OT problems instantiated when computing a Sinkhorn divergence between two probability distributions.
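To illustrate how the divergence combines these problems, the following minimal NumPy sketch computes the standard debiased quantity OT_ε(a, b) − ½ OT_ε(a, a) − ½ OT_ε(b, b) with a hand-written log-domain Sinkhorn loop. It is not the ott implementation: the squared-Euclidean cost, the use of the dual value ⟨a, f⟩ + ⟨b, g⟩ as the regularized cost, and the helpers `toy_sinkhorn` and `toy_sinkhorn_divergence` are assumptions made for illustration. The returned tuples only mirror the layout of the potentials, converged and n_iters fields.

```python
import numpy as np


def _logsumexp(m: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(m))) along `axis`."""
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def toy_sinkhorn(x, y, a, b, epsilon=1.0, tol=1e-8, max_iter=10_000):
    """Hypothetical log-domain Sinkhorn for the cost ||x - y||^2, x [n, d], y [m, d].

    Returns (reg_ot_cost, f [n,], g [m,], converged, n_iters).
    """
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # [n, m]
    log_a, log_b = np.log(a), np.log(b)
    f, g = np.zeros(len(a)), np.zeros(len(b))
    converged, n_iters = False, 0
    while n_iters < max_iter:
        n_iters += 1
        # Alternate the two soft c-transforms.
        f = -epsilon * _logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        g = -epsilon * _logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
        # The g-update makes the column marginals exact; check the row marginals
        # of P_ij = a_i b_j exp((f_i + g_j - C_ij) / eps).
        plan = np.exp(log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / epsilon)
        if np.abs(plan.sum(axis=1) - a).sum() < tol:
            converged = True
            break
    # At convergence the regularized OT cost equals the dual value <a, f> + <b, g>.
    return float(a @ f + b @ g), f, g, converged, n_iters


def toy_sinkhorn_divergence(x, y, a, b, epsilon=1.0):
    """Debiased divergence from three problems of sizes [n, m], [n, n], [m, m]."""
    xy = toy_sinkhorn(x, y, a, b, epsilon)
    xx = toy_sinkhorn(x, x, a, a, epsilon)
    yy = toy_sinkhorn(y, y, b, b, epsilon)
    divergence = xy[0] - 0.5 * (xx[0] + yy[0])
    potentials = ((xy[1], xy[2]), (xx[1], xx[2]), (yy[1], yy[2]))
    converged = (xy[3], xx[3], yy[3])
    n_iters = (xy[4], xx[4], yy[4])
    return divergence, potentials, converged, n_iters


rng = np.random.default_rng(0)
x = rng.uniform(size=(5, 2))
y = rng.uniform(size=(4, 2)) + 1.0  # shifted cloud, so the divergence is positive
a = np.full(5, 1 / 5)
b = np.full(4, 1 / 4)
div, potentials, converged, n_iters = toy_sinkhorn_divergence(x, y, a, b)
print(div, converged, n_iters)
```

Subtracting the two symmetric terms removes the entropic bias, so the divergence of a distribution with itself is zero.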

Parameters:
  • divergence (float) – value of the Sinkhorn divergence

  • geoms (Tuple[Geometry, Geometry, Geometry]) – three geometries describing the Sinkhorn divergence, of respective sizes [n, m], [n, n] and [m, m] if their cost or kernel matrices were instantiated.

  • a (Array) – first [n,] vector of marginal weights.

  • b (Array) – second [m,] vector of marginal weights.

  • potentials (Optional[Tuple[Tuple[Array, Array], Tuple[Array, Array], Tuple[Array, Array]]]) – three pairs of dual potential vectors, of sizes ([n,], [m,]), ([n,], [n,]) and ([m,], [m,]), returned when the solve() calls used to compute the divergence rely on a vanilla Sinkhorn solver.

  • factors (Optional[Tuple[Tuple[Array, Array, Array], Tuple[Array, Array, Array], Tuple[Array, Array, Array]]]) – three triplets of matrices, of sizes ([n, rank], [m, rank], [rank,]), ([n, rank], [n, rank], [rank,]) and ([m, rank], [m, rank], [rank,]), returned when the solve() calls used to compute the divergence rely on a low-rank LRSinkhorn solver.

  • converged (Tuple[bool, bool, bool]) – triplet of booleans indicating whether each of the three problems solved to compute the divergence converged.

  • n_iters (Tuple[int, int, int]) – number of iterations needed to solve each of the three terms in the divergence, which tracks the computational effort of each.

  • errors (Tuple[Optional[Array], Optional[Array], Optional[Array]])
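The factors triplets store each coupling implicitly. The NumPy sketch below shows how a ([n, rank], [m, rank], [rank,]) triplet could be expanded into a dense [n, m] matrix, assuming the low-rank convention P = Q diag(1/g) Rᵀ used by factored (LRSinkhorn-style) solvers; that convention, the helper `coupling_from_factors`, and the toy factors are assumptions for illustration, not the ott code path. The toy factors have the expected marginals (Q1 = a, Qᵀ1 = g, R1 = b, Rᵀ1 = g) and reduce to the independent coupling a bᵀ.

```python
import numpy as np


def coupling_from_factors(q, r, g):
    """Hypothetical helper: dense [n, m] coupling from (q [n, rank], r [m, rank], g [rank,]).

    Assumes the low-rank convention P = q @ diag(1 / g) @ r.T.
    """
    return (q / g[None, :]) @ r.T


n, m, rank = 6, 4, 2
rng = np.random.default_rng(0)
a = rng.dirichlet(np.ones(n))  # [n,] marginal weights
b = rng.dirichlet(np.ones(m))  # [m,] marginal weights
g = np.array([0.3, 0.7])       # [rank,] inner marginal
# Toy factors with the required marginals; here P collapses to a b^T.
q = a[:, None] * g[None, :]    # [n, rank]
r = b[:, None] * g[None, :]    # [m, rank]
P = coupling_from_factors(q, r, g)
print(P.shape)  # (6, 4)
```

The factored form needs (n + m + 1) · rank numbers instead of n · m, which is why only the factors are returned for low-rank solvers.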

Methods

to_dual_potentials()

Return the dual potential functions [Pooladian et al., 2022].
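Dual potential *functions* extend the discrete potential vectors to arbitrary inputs. The NumPy sketch below, a hedged illustration rather than the ott implementation, extends a potential g defined on target points y to any x through the entropic (soft) c-transform, and derives an entropic transport map by barycentric projection, a construction of the kind the summary above refers to. The squared-Euclidean cost and the helper `dual_potential_fns` are assumptions; which of the three potential pairs to_dual_potentials() actually uses is not stated here.

```python
import numpy as np


def _logsumexp(m, axis):
    """Numerically stable log(sum(exp(m))) along `axis`."""
    mx = np.max(m, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(m - mx), axis=axis))


def dual_potential_fns(y, b, g, epsilon):
    """Hypothetical helper: extend g [m,] on points y [m, d] to functions of x [k, d].

    Assumes the cost c(x, y) = ||x - y||^2.
    """
    log_b = np.log(b)

    def log_weights(x):
        cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # [k, m]
        return log_b[None, :] + (g[None, :] - cost) / epsilon

    def f(x):
        # Soft c-transform: f(x) = -eps * log sum_j b_j exp((g_j - c(x, y_j)) / eps).
        return -epsilon * _logsumexp(log_weights(x), axis=1)

    def transport(x):
        # Barycentric projection x -> sum_j w_j(x) y_j with softmax weights w(x);
        # for this cost it equals x - grad f(x) / 2.
        lw = log_weights(x)
        w = np.exp(lw - _logsumexp(lw, axis=1)[:, None])
        return w @ y

    return f, transport


# Usage: with a single target point and g = 0, f(x) reduces to ||x - y_0||^2.
y = np.array([[1.0, 2.0]])
f, transport = dual_potential_fns(y, b=np.array([1.0]), g=np.array([0.0]), epsilon=0.5)
print(f(np.array([[0.0, 0.0], [1.0, 1.0]])))  # [5. 1.]
```

Working in the log domain keeps the evaluation stable for small ε, where the raw kernel exp(-c/ε) would underflow.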

Attributes