ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput
- class ott.tools.sinkhorn_divergence.SinkhornDivergenceOutput(divergence, geoms, a, b, potentials, factors, errors, converged, n_iters)
Holds the outputs of a call to sinkhorn_divergence(). Objects of this class contain both the solutions and the problem definitions of the two or three regularized OT problems instantiated when computing a Sinkhorn divergence between two probability distributions.
- Parameters:
divergence (float) – value of the Sinkhorn divergence.
geoms (Tuple[Geometry, Geometry, Geometry]) – three geometries describing the Sinkhorn divergence, of respective sizes [n, m], [n, n], [m, m] if their cost or kernel matrices were instantiated.
a (Array) – first [n,] vector of marginal weights.
b (Array) – second [m,] vector of marginal weights.
potentials (Optional[Tuple[Tuple[Array, Array], Tuple[Array, Array], Tuple[Array, Array]]]) – three pairs of dual potential vectors, of sizes ([n,], [m,]), ([n,], [n,]) and ([m,], [m,]), returned when the solve() call used to compute the divergence relies on a vanilla Sinkhorn solver.
factors (Optional[Tuple[Tuple[Array, Array, Array], Tuple[Array, Array, Array], Tuple[Array, Array, Array]]]) – three triplets of matrices, of sizes ([n, rank], [m, rank], [rank,]), ([n, rank], [n, rank], [rank,]) and ([m, rank], [m, rank], [rank,]), returned when the solve() call used to compute the divergence relies on a low-rank LRSinkhorn solver.
converged (Tuple[bool, bool, bool]) – triplet of booleans indicating the convergence of each of the three problems run to compute the divergence.
n_iters (Tuple[int, int, int]) – number of iterations, keeping track of the compute effort needed to complete each of the three terms in the divergence.
errors (Tuple[Optional[Array], Optional[Array], Optional[Array]])
Methods
Return dual potential functions, [Pooladian et al., 2022].
Attributes
Whether the output is low-rank.
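The low-rank flag can be read against the parameter list, where potentials is tied to a vanilla Sinkhorn solve and factors to an LRSinkhorn solve. The sketch below illustrates one plausible check under that reading; `ToyOutput` and `looks_low_rank` are hypothetical names, and the rule "factors is set" is an assumption, not a statement about how the library computes the attribute.

```python
from typing import NamedTuple, Optional, Tuple


class ToyOutput(NamedTuple):
    """Hypothetical stand-in holding only the two optional fields."""
    potentials: Optional[Tuple] = None
    factors: Optional[Tuple] = None


def looks_low_rank(out: ToyOutput) -> bool:
    # Assumption: a low-rank solve fills `factors` and leaves `potentials` unset.
    return out.factors is not None


# Placeholder values; only the presence of each field matters here.
dense = ToyOutput(potentials=(([0.0], [0.0]),) * 3)
low_rank = ToyOutput(factors=(([[1.0]], [[1.0]], [1.0]),) * 3)
print(looks_low_rank(dense), looks_low_rank(low_rank))
```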