ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer
- class ott.initializers.linear.initializers_lr.GeneralizedKMeansInitializer(rank, gamma=10.0, min_iterations=0, max_iterations=100, inner_iterations=10, threshold=1e-06, sinkhorn_kwargs=None)
Generalized k-means initializer [Scetbon and Cuturi, 2022].
Applicable to any Geometry with a square shape.
- Parameters:
  - rank (int) – Rank of the factorization.
  - gamma (float) – The (inverse of) gradient step size used by mirror descent.
  - min_iterations (int) – Minimum number of iterations.
  - max_iterations (int) – Maximum number of iterations.
  - inner_iterations (int) – Number of iterations used by the algorithm before re-evaluating progress.
  - threshold (float) – Convergence threshold.
  - sinkhorn_kwargs (Optional[Mapping[str, Any]]) – Keyword arguments for Sinkhorn.
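The iteration parameters above plausibly interact as follows: updates run in blocks of inner_iterations, the convergence error is compared against threshold only between blocks, and the loop stops no earlier than min_iterations and no later than max_iterations. The following is a minimal pure-Python sketch of that schedule, not the library's implementation; run_with_schedule, step, and error are hypothetical names introduced here for illustration.

```python
def run_with_schedule(step, error, state, *, min_iterations=0,
                      max_iterations=100, inner_iterations=10,
                      threshold=1e-6):
    """Run `step` in blocks of `inner_iterations`, checking `error` between blocks.

    Hypothetical helper: it only illustrates how the four parameters
    could interact, under the assumptions stated in the text above.
    """
    iteration = 0
    while iteration < max_iterations:
        # Run one block of updates without evaluating the error.
        block = min(inner_iterations, max_iterations - iteration)
        for _ in range(block):
            state = step(state)
            iteration += 1
        # Progress is re-evaluated only at block boundaries.
        if iteration >= min_iterations and error(state) < threshold:
            break
    return state, iteration


# Toy fixed-point problem (an assumption for the demo): x <- x / 2 tends to 0.
state, n_iter = run_with_schedule(lambda x: 0.5 * x, abs, 1.0)
```

Because the error is only inspected between blocks, the loop can overshoot the first iteration at which the threshold is met by up to inner_iterations - 1 steps; a larger inner_iterations trades fewer convergence checks for that overshoot.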
Methods
- init_g(ot_prob, rng, **kwargs) – Initialize the low-rank factor \(g\).
- init_q(ot_prob, rng, *, init_g, **kwargs) – Initialize the low-rank factor \(Q\).
- init_r(ot_prob, rng, *, init_g, **kwargs) – Initialize the low-rank factor \(R\).
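For context, the factors \(Q\), \(R\), and \(g\) parameterize a low-rank coupling, which in the formulation of Scetbon and Cuturi is written \(P = Q \operatorname{diag}(1/g) R^\top\), with \(Q\) having marginals \((a, g)\) and \(R\) having marginals \((b, g)\). The sketch below is not what this initializer computes: it builds the trivial factors \(Q = a g^\top\) and \(R = b g^\top\) in pure Python, only to show how the three pieces fit together. The helper names and the toy marginals are assumptions made for illustration.

```python
def trivial_factors(a, b, g):
    """Independent-coupling factors: Q = a g^T and R = b g^T (illustrative only)."""
    q = [[ai * gk for gk in g] for ai in a]
    r = [[bj * gk for gk in g] for bj in b]
    return q, r


def assemble_coupling(q, r, g):
    """Return P = Q diag(1/g) R^T as a nested list."""
    rank = len(g)
    return [
        [sum(qi[k] * rj[k] / g[k] for k in range(rank)) for rj in r]
        for qi in q
    ]


# Toy marginals (assumed values); each vector sums to 1.
a = [0.2, 0.3, 0.5]    # source marginal
b = [0.5, 0.25, 0.25]  # target marginal, same size as `a` (square problem)
g = [0.25, 0.75]       # inner marginal, one weight per rank component

q, r = trivial_factors(a, b, g)
p = assemble_coupling(q, r, g)

row_sums = [sum(row) for row in p]        # marginal of P over columns
col_sums = [sum(col) for col in zip(*p)]  # marginal of P over rows
```

With these factors, row_sums and col_sums recover a and b up to floating-point error, which is the marginal constraint any valid initialization of \(Q\), \(R\), and \(g\) has to satisfy.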
Attributes
- rank – Rank of the transport matrix factorization.