ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer

class ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer(rank, gamma=10.0, min_iterations=0, max_iterations=100, inner_iterations=10, threshold=1e-06, sinkhorn_kwargs=None)

Generalized k-means initializer [Scetbon and Cuturi, 2022].

Applicable to any Geometry with a square shape.

Parameters:
  • rank (int) – Rank of the factorization.

  • gamma (float) – The (inverse of) gradient step size used by mirror descent.

  • min_iterations (int) – Minimum number of iterations.

  • max_iterations (int) – Maximum number of iterations.

  • inner_iterations (int) – Number of iterations used by the algorithm before re-evaluating progress.

  • threshold (float) – Convergence threshold.

  • sinkhorn_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments for Sinkhorn.
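The iteration parameters above can be read together as a block-wise loop: run inner_iterations updates, then compare an error measure against threshold, while respecting min_iterations and max_iterations. The following is a minimal sketch of that control flow under this assumed reading, not the library's implementation; run_with_checks, step and error are hypothetical names introduced only for illustration.

```python
def run_with_checks(step, error, state, *, min_iterations=0, max_iterations=100,
                    inner_iterations=10, threshold=1e-6):
    """Hypothetical driver loop illustrating how the iteration parameters interact.

    `step` maps a state to the next state; `error` maps a state to a float.
    Both are placeholders and are not part of the documented API.
    """
    iteration = 0
    while iteration < max_iterations:
        # Run a block of updates without checking convergence.
        for _ in range(min(inner_iterations, max_iterations - iteration)):
            state = step(state)
            iteration += 1
        # Progress is re-evaluated only at block boundaries.
        if iteration >= min_iterations and error(state) < threshold:
            break
    return state, iteration


# Toy problem: halve x at every step; the error is |x|.
final_state, num_iterations = run_with_checks(lambda x: 0.5 * x, abs, 1.0)
```

In this toy run the error first drops below the threshold partway through the second block, but the loop only notices at the next check, so a larger inner_iterations trades fewer convergence checks for possibly a few extra updates.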

Methods

from_solver(solver, *, kind, **kwargs)

Create a low-rank initializer from a linear or quadratic solver.

init_g(ot_prob, rng, **kwargs)

Initialize the low-rank factor \(g\).

init_q(ot_prob, rng, *, init_g, **kwargs)

Initialize the low-rank factor \(Q\).

init_r(ot_prob, rng, *, init_g, **kwargs)

Initialize the low-rank factor \(R\).
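To show how the three factors fit together, the sketch below assumes the usual low-rank parameterization of the coupling, \(P = Q \operatorname{diag}(1/g) R^\top\), where \(Q\) has marginals \(a\) and \(g\), and \(R\) has marginals \(b\) and \(g\). It uses trivial product factors rather than the k-means based initialization of this class, and the helpers outer and low_rank_coupling are hypothetical, written with plain Python lists.

```python
def outer(u, v):
    """Outer product of two vectors given as lists."""
    return [[ui * vj for vj in v] for ui in u]


def low_rank_coupling(q, r, g):
    """Assemble P[i][j] = sum_k Q[i][k] * R[j][k] / g[k]."""
    rank = len(g)
    return [[sum(qi[k] * rj[k] / g[k] for k in range(rank)) for rj in r]
            for qi in q]


a = [0.5, 0.25, 0.25]   # source marginal (n = 3)
b = [0.25, 0.25, 0.5]   # target marginal (m = 3)
g = [0.5, 0.5]          # inner marginal, rank = 2

# Trivial feasible factors: Q = a g^T and R = b g^T.
q = outer(a, g)
r = outer(b, g)

p = low_rank_coupling(q, r, g)
row_sums = [sum(row) for row in p]        # should match a
col_sums = [sum(col) for col in zip(*p)]  # should match b
```

With these product factors the assembled coupling reduces to the independent coupling of \(a\) and \(b\); an actual initializer would produce more informative \(Q\), \(R\) and \(g\) that satisfy the same marginal constraints.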

Attributes

rank

Rank of the transport matrix factorization.