ott.geometry.costs.Bures.covariance_fixpoint_iter

Bures.covariance_fixpoint_iter(covs, weights, tolerance=0.0001, sqrtm_kw=None, **kwargs)

Iterate fixed-point updates to compute the covariance of the barycenter of Gaussians.
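For reference, here is a standard fixed-point scheme for the covariance of the Bures (Wasserstein-2) barycenter of Gaussians with covariances \Sigma_i and weights w_i. This page does not say that the method uses exactly this update, so treat that as an assumption. Starting from a symmetric positive-definite S_0 (for example the identity), iterate the update below. A fixed point S satisfies the second equation.

```latex
S_{k+1} = S_k^{-1/2}\Big(\sum_{i=1}^{n} w_i \big(S_k^{1/2}\,\Sigma_i\,S_k^{1/2}\big)^{1/2}\Big)^{2} S_k^{-1/2},
\qquad
S = \sum_{i=1}^{n} w_i \big(S^{1/2}\,\Sigma_i\,S^{1/2}\big)^{1/2}
```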

Parameters:
  • covs (Array) – [batch, d^2] covariance matrices

  • weights (Array) – simplicial weights (non-negative, sum to 1)

  • tolerance (float) – tolerance of the fixed-point procedure. It is compared against the Frobenius norm (normalized by the total number of entries) of the difference between two successive iterates of the algorithm

  • sqrtm_kw (Optional[Dict[str, Any]]) – keyword arguments for sqrtm()

  • kwargs (Any) – keyword arguments for the outer fixed-point iteration

Return type:

Array

Returns:

List containing the weighted Bures average of the covariance matrices and a vector of (normalized) 2-norms of the differences between successive iterates, used to monitor convergence.
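When the covariance matrices commute, for example when they are all diagonal, the weighted Bures average has the closed form (\sum_i w_i \Sigma_i^{1/2})^2. This gives a simple sanity check for the returned average. The NumPy sketch below works on diagonals only. The function name `diagonal_bures_barycenter` is hypothetical and not part of OTT.

```python
import numpy as np


def diagonal_bures_barycenter(diag_covs, weights):
    """Closed-form Bures barycenter for diagonal covariances, given as their diagonals."""
    diag_covs = np.asarray(diag_covs, dtype=float)  # shape [batch, d]
    weights = np.asarray(weights, dtype=float)      # shape [batch]
    return (weights @ np.sqrt(diag_covs)) ** 2      # diagonal of the barycenter


diag_covs = [[1.0, 4.0], [9.0, 16.0]]
weights = [0.5, 0.5]
bary = diagonal_bures_barycenter(diag_covs, weights)  # array([4., 9.])

# For diagonal matrices the fixed-point equation
# S = sum_i w_i (S^{1/2} Sigma_i S^{1/2})^{1/2} becomes elementwise.
rhs = np.asarray(weights) @ np.sqrt(bary * np.asarray(diag_covs))
```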