ott.core.discrete_barycenter.discrete_barycenter
- ott.core.discrete_barycenter.discrete_barycenter(geom, a, weights=None, dual_initialization=None, threshold=0.01, norm_error=1, inner_iterations=10, min_iterations=0, max_iterations=2000, lse_mode=True, debiased=False)
Compute the discrete barycenter of a batch of histograms defined on a common geometry, following the Sinkhorn-based approach of https://arxiv.org/abs/2006.02575.
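To illustrate what a discrete (fixed-support) entropic barycenter is, the sketch below implements the classic, non-debiased iterative Bregman projection scheme of Benamou et al. (2015) in NumPy. It is not OTT's implementation and not the debiased algorithm of the cited paper; the function name `sinkhorn_barycenter`, the explicit `cost` matrix and the `epsilon` argument are assumptions made for illustration, and all histograms and the barycenter are assumed to share a single support.

```python
import numpy as np


def sinkhorn_barycenter(a, cost, weights=None, epsilon=0.05,
                        max_iterations=2000, threshold=1e-6):
    """Hypothetical sketch: entropic barycenter via iterative Bregman projections.

    a: (batch, n) array, each row a strictly positive histogram summing to 1.
    cost: (n, n) ground-cost matrix shared by the inputs and the barycenter.
    """
    batch, n = a.shape
    if weights is None:
        weights = np.full(batch, 1.0 / batch)  # assumption: uniform weights
    kernel = np.exp(-cost / epsilon)  # Gibbs kernel K = exp(-C / eps)
    v = np.ones((batch, n))  # scalings on the barycenter side, one row per input
    bary = np.full(n, 1.0 / n)
    for _ in range(max_iterations):
        u = a / (v @ kernel.T)  # project onto {P_k 1 = a_k}: u_k = a_k / (K v_k)
        kv = u @ kernel  # row k holds K^T u_k
        new_bary = np.exp(weights @ np.log(kv))  # weighted geometric mean
        v = new_bary[None, :] / kv  # project onto {P_k^T 1 = bary}
        change = np.abs(new_bary - bary).sum()
        bary = new_bary
        if change < threshold:  # stop once the barycenter stops moving
            break
    return bary / bary.sum()
```

For example, on a 1D grid with squared-Euclidean cost, the barycenter of two bumps centered at 0.25 and 0.75 concentrates around 0.5 rather than reproducing the two bumps, as a plain average of the histograms would.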
- Parameters
geom (Geometry) – a Geometry object able to apply kernels with a given epsilon.
a (ndarray) – jnp.ndarray<float>[batch, geom.num_a]: batch of histograms.
weights (Optional[ndarray]) – jnp.ndarray<float>[batch]: weights in the probability simplex, one per histogram.
dual_initialization (Optional[ndarray]) – jnp.ndarray<float>[batch, num_b]: initialization for g_v.
threshold (float) – tolerance used to monitor convergence.
norm_error (int) – power p used to define the p-norm of the marginal error.
inner_iterations (int) – the Sinkhorn error is not recomputed at every iteration but only every inner_iterations iterations, to avoid computational overhead.
min_iterations (int) – the minimum number of Sinkhorn iterations carried out before the error is computed and monitored.
max_iterations (int) – the maximum number of Sinkhorn iterations.
lse_mode (bool) – True for log-sum-exp computations, False for kernel multiplication.
debiased (bool) – whether to run the debiased version of the barycenter, based on the Sinkhorn divergence.
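The convergence parameters above (threshold, norm_error, inner_iterations, min_iterations, max_iterations, lse_mode) can be illustrated on plain two-marginal Sinkhorn. The NumPy sketch below is a simplified stand-in, not OTT's code: `sinkhorn`, the explicit `cost` matrix and `epsilon` are assumptions, and the monitored error is taken to be the p-norm of the row-marginal violation. Both modes run the same iteration; they differ only in whether it is written with log-sum-exp on potentials or with kernel multiplications on scalings.

```python
import numpy as np


def _logsumexp(x, axis):
    # Numerically stable log(sum(exp(x))) along `axis`.
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def sinkhorn(cost, a, b, epsilon=0.05, threshold=1e-2, norm_error=1,
             inner_iterations=10, min_iterations=0, max_iterations=2000,
             lse_mode=True):
    """Hypothetical sketch returning potentials (f, g) and the recorded errors."""
    f = np.zeros_like(a)
    g = np.zeros_like(b)
    kernel = np.exp(-cost / epsilon)
    errors = []
    for it in range(1, max_iterations + 1):
        if lse_mode:
            # Log-domain updates: f_i = eps*log a_i - eps*LSE_j((g_j - C_ij) / eps).
            f = epsilon * np.log(a) - epsilon * _logsumexp((g[None, :] - cost) / epsilon, axis=1)
            g = epsilon * np.log(b) - epsilon * _logsumexp((f[:, None] - cost) / epsilon, axis=0)
        else:
            # Kernel form: u = a / (K v), v = b / (K^T u), with u = exp(f/eps), v = exp(g/eps).
            v = np.exp(g / epsilon)
            u = a / (kernel @ v)
            v = b / (kernel.T @ u)
            f, g = epsilon * np.log(u), epsilon * np.log(v)
        # The error is only computed every `inner_iterations` steps, after `min_iterations`.
        if it >= min_iterations and it % inner_iterations == 0:
            plan = np.exp((f[:, None] + g[None, :] - cost) / epsilon)
            err = np.linalg.norm(plan.sum(axis=1) - a, ord=norm_error)
            errors.append(err)
            if err < threshold:
                break
    return f, g, np.array(errors)
```

Checking the error only every few iterations trades a slightly late stop (at most inner_iterations - 1 extra updates) for skipping the cost of forming the plan at every step.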
- Return type
Barycenter
- Returns
A SinkhornBarycenterOutput, which contains two arrays of potentials, each of size [batch, geom.num_a], summarizing the OT between each histogram in the batch and the barycenter (stored in histogram), as well as a sequence of errors that monitors convergence.
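The returned potentials can, in principle, be turned back into one transport plan per input histogram. The sketch below assumes the log-domain convention P_k[i, j] = exp((f_k[i] + g_k[j] - C[i, j]) / epsilon), an explicit cost matrix, and that unrecorded entries of the error sequence are padded with -1; all three are assumptions to check against the installed OTT version, and `plans_from_potentials` / `last_valid_error` are hypothetical helpers, not OTT functions.

```python
import numpy as np


def plans_from_potentials(f, g, cost, epsilon):
    """Hypothetical helper: rebuild one transport plan per histogram.

    f: (batch, n) potentials, g: (batch, m) potentials, cost: (n, m).
    Returns an array of shape (batch, n, m), assuming the log-domain convention
    P_k[i, j] = exp((f_k[i] + g_k[j] - C[i, j]) / epsilon).
    """
    return np.exp((f[:, :, None] + g[:, None, :] - cost[None, :, :]) / epsilon)


def last_valid_error(errors, missing=-1.0):
    """Hypothetical helper: last recorded error, skipping assumed -1 padding."""
    recorded = errors[errors != missing]
    return recorded[-1] if recorded.size else None
```

With converged potentials, each plan's row sums should approximately match the corresponding input histogram and its column sums the barycenter histogram.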