ott.geometry.geometry.Geometry.to_LRCGeometry

Geometry.to_LRCGeometry(rank=0, tol=0.01, rng=None, scale=1.0)

Factorize the cost matrix using either a full SVD or the sublinear-time algorithm of [Indyk et al., 2019].

When rank=min(n, m) or rank=0 (the default), the factorization is computed exactly with jax.numpy.linalg.svd().
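To make the full-SVD path concrete, here is a minimal sketch of how an exact two-factor decomposition \(A = L R^\top\) can be read off jax.numpy.linalg.svd(). It assumes the singular values are split evenly between the two factors (\(\sqrt{s}\) on each side); the helper `svd_factors` and the names `left`/`right` are hypothetical and not part of OTT, and its `rank` argument only illustrates truncation.

```python
import jax
import jax.numpy as jnp


def svd_factors(A, rank=0):
    """Hypothetical helper: return (left, right) with A ≈ left @ right.T.

    rank=0 keeps every singular value (exact factorization); a positive
    rank truncates to the top-`rank` singular triplets.
    """
    # Thin SVD: U is (n, r), s is (r,), Vt is (r, m) with r = min(n, m).
    U, s, Vt = jnp.linalg.svd(A, full_matrices=False)
    if rank > 0:
        U, s, Vt = U[:, :rank], s[:rank], Vt[:rank, :]
    # Split each singular value as sqrt(s_j) * sqrt(s_j) between the factors.
    root_s = jnp.sqrt(s)
    left = U * root_s      # (n, r): column j scaled by sqrt(s_j)
    right = Vt.T * root_s  # (m, r): column j scaled by sqrt(s_j)
    return left, right


A = jax.random.uniform(jax.random.PRNGKey(0), (6, 4))
left, right = svd_factors(A)
print(left.shape, right.shape)  # (6, 4) (4, 4)
print(jnp.allclose(left @ right.T, A, atol=1e-4))
```

The symmetric split keeps both factors on a comparable scale; absorbing all of \(s\) into one side would be an equally valid exact factorization.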

For any other value of rank, the sublinear-time routine of [Indyk et al., 2019] is used, following the implementation of [Scetbon et al., 2021], Algorithm 4.
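The sampling idea behind such a routine can be sketched as follows. This is a rough, hypothetical re-implementation written from a general reading of row/column-sampling low-rank schemes; the function `sampled_lowrank` and every intermediate step are assumptions and may differ from OTT's actual code. For simplicity it takes a dense array `A`, although only a probe row, a probe column, about rank/tol sampled rows and about rank/tol sampled columns of it are accessed. Rows are drawn with probabilities built from the probe row and column, columns are then drawn from the row-subsampled matrix, a small SVD yields the right factor, and the left factor is fitted by least squares on uniformly sampled columns.

```python
import jax
import jax.numpy as jnp


def sampled_lowrank(A, rank, tol, key):
    """Hypothetical sketch: return (left, right) with A ≈ left @ right.T."""
    n, m = A.shape
    n_subset = min(int(rank / tol), n, m)
    k1, k2, k3, k4, k5 = jax.random.split(key, 5)

    # 1) Row-sampling probabilities from one random row and one random column.
    i_star = jax.random.randint(k1, (), 0, n)
    j_star = jax.random.randint(k2, (), 0, m)
    row_sq = A[i_star, :] ** 2  # (m,)
    col_sq = A[:, j_star] ** 2  # (n,)
    p_row = col_sq + row_sq[j_star] + jnp.mean(row_sq)
    p_row = p_row / jnp.sum(p_row)
    rows = jax.random.choice(k3, n, shape=(n_subset,), p=p_row)
    # Rescaled sampled rows, shape (n_subset, m).
    S = A[rows, :] / jnp.sqrt(n_subset * p_row[rows])[:, None]

    # 2) Column sampling driven by the squared column norms of S.
    p_col = jnp.sum(S ** 2, axis=0)
    p_col = p_col / jnp.sum(p_col)
    cols = jax.random.choice(k4, m, shape=(n_subset,), p=p_col)
    # Small rescaled block, shape (n_subset, n_subset).
    W = S[:, cols] / jnp.sqrt(n_subset * p_col[cols])[None, :]

    # 3) Top-`rank` left singular vectors of W, lifted to an (m, rank) basis.
    U_w = jnp.linalg.svd(W)[0][:, :rank]
    right = (S.T @ U_w) / jnp.linalg.norm(W.T @ U_w, axis=0)

    # 4) Orthonormalize `right` using the eigendecomposition of right.T @ right.
    d, E = jnp.linalg.eigh(right.T @ right)
    whiten = E / jnp.sqrt(d)[None, :]
    Q = right @ whiten  # (m, rank), orthonormal columns

    # 5) Least-squares fit of A ≈ C @ Q.T on uniformly sampled columns.
    cols_ls = jax.random.choice(k5, m, shape=(n_subset,))
    A_c, Q_c = A[:, cols_ls], Q[cols_ls, :]
    C = A_c @ Q_c @ jnp.linalg.inv(Q_c.T @ Q_c)  # (n, rank)
    # C @ Q.T == (C @ whiten.T) @ right.T, so fold `whiten` into the left factor.
    left = C @ whiten.T
    return left, right


key_x, key_y, key_s = jax.random.split(jax.random.PRNGKey(0), 3)
# An exactly rank-3 matrix, so a rank-3 sketch can recover it.
A = jax.random.normal(key_x, (120, 3)) @ jax.random.normal(key_y, (3, 150))
left, right = sampled_lowrank(A, rank=3, tol=0.1, key=key_s)
rel_err = jnp.linalg.norm(A - left @ right.T) / jnp.linalg.norm(A)
print(left.shape, right.shape)  # (120, 3) (150, 3)
print(rel_err)
```

On an exactly low-rank input the sketch recovers the matrix up to floating-point error; on general inputs it only targets the probabilistic bound stated for the routine.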

With probability 0.99, the factorization \(UV\) computed in sublinear time satisfies \(\|A - UV\|_F^2 \leq \|A - A_k\|_F^2 + \mathrm{tol} \cdot \|A\|_F^2\), where \(A\) is the \(n \times m\) cost matrix and \(A_k\) is its best rank-\(k\) approximation, with \(k\) given by rank.
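Both sides of this bound can be evaluated directly for small matrices. By the Eckart-Young theorem, \(\|A - A_k\|_F^2\) equals the sum of the squared singular values of \(A\) beyond the first \(k\). The sketch below computes that quantity and the full right-hand side, then checks it against the truncated SVD, which attains \(A_k\) exactly; the helper names `best_rank_k_error`, `error_bound` and `factorization_error` are hypothetical.

```python
import jax
import jax.numpy as jnp


def best_rank_k_error(A, k):
    """||A - A_k||_F^2 via Eckart-Young: sum of squared tail singular values."""
    s = jnp.linalg.svd(A, compute_uv=False)
    return jnp.sum(s[k:] ** 2)


def error_bound(A, k, tol):
    """Right-hand side of the bound: ||A - A_k||_F^2 + tol * ||A||_F^2."""
    return best_rank_k_error(A, k) + tol * jnp.sum(A ** 2)


def factorization_error(A, U, V):
    """Left-hand side: ||A - U V||_F^2 for U of shape (n, k), V of shape (k, m)."""
    return jnp.sum((A - U @ V) ** 2)


A = jax.random.normal(jax.random.PRNGKey(0), (50, 40))
k, tol = 5, 0.01

# The truncated SVD is A_k itself, so it meets the bound with slack tol * ||A||_F^2.
U_full, s, Vt = jnp.linalg.svd(A, full_matrices=False)
U, V = U_full[:, :k] * s[:k], Vt[:k, :]
lhs = factorization_error(A, U, V)
rhs = error_bound(A, k, tol)
print(bool(lhs <= rhs))  # True
```

Any other factorization, including a sampled one, can be plugged into `factorization_error` in the same way to see how much of the tol slack it uses.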

Parameters:
  • rank (int) – Target rank of the factorization. A value of 0 or min(n, m) selects the full SVD.

  • tol (float) – Error tolerance, i.e. the coefficient of \(\|A\|_F^2\) in the bound above. The total number of sampled points is \(\min(n, m, \frac{\mathrm{rank}}{\mathrm{tol}})\).

  • rng (Optional[Array]) – PRNG key used for the random sampling performed by the sublinear-time routine.

  • scale (float) – Value used to rescale the factors of the low-rank geometry. Useful when this geometry is used in the linear term of fused Gromov-Wasserstein (GW).
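As a worked example of how tol controls the amount of sampling, the formula \(\min(n, m, \frac{\mathrm{rank}}{\mathrm{tol}})\) from the tol entry can be evaluated directly. Rounding rank/tol down to an integer is an assumption of this sketch, and `num_sampled_points` is a hypothetical helper.

```python
def num_sampled_points(n: int, m: int, rank: int, tol: float) -> int:
    """Hypothetical helper: min(n, m, rank / tol), rounded down to an int."""
    return min(n, m, int(rank / tol))


# Default tol=0.01 with rank=5 asks for 500 points; a large matrix allows that.
print(num_sampled_points(10_000, 8_000, rank=5, tol=0.01))  # 500
# On a 300 x 400 cost matrix the request is capped at min(n, m) = 300.
print(num_sampled_points(300, 400, rank=5, tol=0.01))  # 300
# Loosening tol to 0.1 shrinks the sample to 5 / 0.1 = 50 points.
print(num_sampled_points(10_000, 8_000, rank=5, tol=0.1))  # 50
```

Since the count grows like rank/tol, halving tol roughly doubles the number of sampled rows and columns until the min(n, m) cap is reached.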

Return type:

LRCGeometry

Returns:

Low-rank geometry.