ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput
- class ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput(divergence, geoms, a, b, potentials, factors, errors, converged, n_iters)
Holds the outputs of a call to sinkhorn_divergence(). Objects of this class contain both the solutions and the problem definitions of the two or three regularized OT problems instantiated when computing a Sinkhorn divergence between two probability distributions.
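The divergence is assembled from three regularized OT problems: a cross problem between the two distributions and one symmetric problem per distribution, in the same order as geoms and potentials. The NumPy sketch below illustrates that structure for the balanced case only; it is not OTT's implementation. It assumes a squared-Euclidean cost, a fixed epsilon, a plain log-domain Sinkhorn loop, and the dual value <f, a> + <g, b> as each problem's regularized cost. The helpers `sq_euclidean`, `sinkhorn_log` and `sinkhorn_divergence_sketch` are hypothetical names.

```python
import numpy as np


def sq_euclidean(x, y):
    # Pairwise squared Euclidean costs, shape [len(x), len(y)].
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along `axis`.
    mx = m.max(axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.exp(m - mx).sum(axis=axis))


def sinkhorn_log(cost, a, b, epsilon, max_iters=2000, tol=1e-8):
    # Log-domain Sinkhorn with reference measure a x b.
    # Returns ((f, g), regularized cost, converged flag, iterations used).
    f, g = np.zeros_like(a), np.zeros_like(b)
    log_a, log_b = np.log(a), np.log(b)
    converged, n_iters = False, 0
    for n_iters in range(1, max_iters + 1):
        f = -epsilon * logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        g = -epsilon * logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
        # Column marginals are exact after the g update; check the row marginals.
        log_plan = log_a[:, None] + log_b[None, :] + (f[:, None] + g[None, :] - cost) / epsilon
        if np.abs(np.exp(log_plan).sum(axis=1) - a).sum() < tol:
            converged = True
            break
    # At convergence the entropic dual objective reduces to <f, a> + <g, b>.
    return (f, g), f @ a + g @ b, converged, n_iters


def sinkhorn_divergence_sketch(x, y, a, b, epsilon=0.1):
    # The three problems, in the same order as `geoms`: (x, y), (x, x), (y, y).
    problems = [(x, y, a, b), (x, x, a, a), (y, y, b, b)]
    potentials, costs, converged, n_iters = [], [], [], []
    for p, q, wp, wq in problems:
        pot, cost, conv, its = sinkhorn_log(sq_euclidean(p, q), wp, wq, epsilon)
        potentials.append(pot)
        costs.append(cost)
        converged.append(conv)
        n_iters.append(its)
    # Balanced Sinkhorn divergence: cross term minus the average of the symmetric terms.
    divergence = costs[0] - 0.5 * (costs[1] + costs[2])
    return divergence, tuple(potentials), tuple(converged), tuple(n_iters)


rng = np.random.default_rng(0)
x = rng.uniform(size=(5, 2))
y = rng.uniform(size=(7, 2)) + 1.0
a = np.full(5, 1 / 5)
b = np.full(7, 1 / 7)

divergence, potentials, converged, n_iters = sinkhorn_divergence_sketch(x, y, a, b)
print(divergence, converged, n_iters)
```

Comparing x with itself makes the three problems identical, so the sketch returns exactly zero; for the shifted cloud y it returns a positive value.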
- Parameters:
  - divergence (float) – value of the Sinkhorn divergence.
  - geoms (Tuple[Geometry, Geometry, Geometry]) – three geometries describing the Sinkhorn divergence, of respective sizes [n, m], [n, n] and [m, m] if their cost or kernel matrices were instantiated.
  - a (Array) – first [n,] vector of marginal weights.
  - b (Array) – second [m,] vector of marginal weights.
  - potentials (Optional[Tuple[Tuple[Array, Array], Tuple[Array, Array], Tuple[Array, Array]]]) – three pairs of dual potential vectors, of sizes ([n,], [m,]), ([n,], [n,]) and ([m,], [m,]), returned when the solve() call that computes the divergence relies on a vanilla Sinkhorn solver.
  - factors (Optional[Tuple[Tuple[Array, Array, Array], Tuple[Array, Array, Array], Tuple[Array, Array, Array]]]) – three triplets of matrices, of sizes ([n, rank], [m, rank], [rank,]), ([n, rank], [n, rank], [rank,]) and ([m, rank], [m, rank], [rank,]), returned when the solve() call that computes the divergence relies on a low-rank LRSinkhorn solver.
  - converged (Tuple[bool, bool, bool]) – triplet of booleans indicating whether each of the three problems solved to compute the divergence converged.
  - n_iters (Tuple[int, int, int]) – number of iterations needed to complete each of the three terms in the divergence, as a measure of compute effort.
  - errors (Tuple[Optional[Array], Optional[Array], Optional[Array]])
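For the low-rank case, each triplet in factors stores a coupling of size [n, m] in factored form. The sketch below assumes the triplet is ordered (q, r, g), as the listed shapes [n, rank], [m, rank], [rank,] suggest, and that the coupling is recovered as q diag(1/g) rᵀ, the parametrization used in low-rank Sinkhorn factorization. `coupling_from_factors` is a hypothetical helper, and the factors are a hand-built toy example rather than solver output.

```python
import numpy as np


def coupling_from_factors(q, r, g):
    # Assumed low-rank parametrization: P = q @ diag(1 / g) @ r.T, of shape [n, m].
    return (q / g[None, :]) @ r.T


n, m = 4, 6
a = np.full(n, 1 / n)       # [n,] first marginal
b = np.full(m, 1 / m)       # [m,] second marginal
g = np.array([0.5, 0.5])    # [rank,] inner marginal, rank = 2

# Toy factors (not solver output): each point is assigned to one of the two
# rank components, so q has marginals (a, g) and r has marginals (b, g).
q = a[:, None] * np.eye(2)[[0, 0, 1, 1]]         # [n, rank]
r = b[:, None] * np.eye(2)[[0, 0, 0, 1, 1, 1]]   # [m, rank]

P = coupling_from_factors(q, r, g)
print(P.shape)                          # (4, 6)
print(np.allclose(P.sum(axis=1), a))    # True
print(np.allclose(P.sum(axis=0), b))    # True
```

Storing the factors takes (n + m + 1) · rank numbers rather than n · m, which is why the factored form pays off when n and m are large.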
Methods
to_dual_potentials([epsilon]) – Return dual potential functions [Pooladian et al., 2022].
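Dual potential vectors are only defined on the support points; a common way to turn them into functions of arbitrary inputs is the entropic (soft) c-transform. The NumPy sketch below illustrates that idea for a single problem with a squared-Euclidean cost. It is an illustration under those assumptions, not what to_dual_potentials returns, and it does not show how the method uses the two symmetric problems. `soft_c_transform` and `sinkhorn_log` are hypothetical helpers.

```python
import numpy as np


def sq_euclidean(x, y):
    # Pairwise squared Euclidean costs, shape [len(x), len(y)].
    return ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def logsumexp(m, axis):
    # Numerically stable log(sum(exp(m))) along `axis`.
    mx = m.max(axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.exp(m - mx).sum(axis=axis))


def soft_c_transform(potential, support, weights, epsilon):
    # Turns a potential vector on `support` into a function of arbitrary points z:
    # z -> -eps * log sum_j weights_j * exp((potential_j - c(z, support_j)) / eps).
    log_w = np.log(weights)

    def fn(z):
        cost = sq_euclidean(np.atleast_2d(z), support)
        return -epsilon * logsumexp(log_w[None, :] + (potential[None, :] - cost) / epsilon, axis=1)

    return fn


def sinkhorn_log(x, y, a, b, epsilon, n_iters=500):
    # Each Sinkhorn half-step is a soft c-transform evaluated on the other support.
    f, g = np.zeros_like(a), np.zeros_like(b)
    for _ in range(n_iters):
        f = soft_c_transform(g, y, b, epsilon)(x)
        g = soft_c_transform(f, x, a, epsilon)(y)
    return f, g


rng = np.random.default_rng(0)
x, y = rng.uniform(size=(6, 2)), rng.uniform(size=(8, 2))
a, b = np.full(6, 1 / 6), np.full(8, 1 / 8)
epsilon = 0.5

f, g = sinkhorn_log(x, y, a, b, epsilon)
f_fn = soft_c_transform(g, y, b, epsilon)  # potential function on the x side
g_fn = soft_c_transform(f, x, a, epsilon)  # potential function on the y side

# On the support points the functions reproduce the (converged) vectors ...
print(np.allclose(f_fn(x), f), np.allclose(g_fn(y), g))
# ... and they can also be evaluated at new, out-of-sample points.
print(f_fn(np.array([0.5, 0.5])))
```

Writing each Sinkhorn half-step as a soft c-transform makes the link explicit: the extended functions agree with the vectors on the support at convergence and remain well-defined elsewhere.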
Attributes
Whether the output is low-rank.