ott.solvers.quadratic.solve
- ott.solvers.quadratic.solve(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, a=None, b=None, tau_a=1.0, tau_b=1.0, loss='sqeucl', gw_unbalanced_correction=True, rank=-1, linear_solver_kwargs=None, **kwargs)
Solve a quadratic regularized OT problem using a Gromov-Wasserstein solver.
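To make the objective concrete, here is a plain-NumPy sketch. It is not OTT code, and the helper names `sq_dists` and `fused_gw_objective` are hypothetical. It evaluates the unregularized cost of a fixed coupling P under the 'sqeucl' loss, which is the sum over i, j, k, l of (Cx[i,k] - Cy[j,l])^2 P[i,j] P[k,l]. For the fused variant, it adds fused_penalty times the linear term sum(M * P). Here Cx, Cy and M stand in for the cost matrices of geom_xx, geom_yy and geom_xy. The entropic regularization and any unbalanced relaxation that the solver applies are deliberately left out.

```python
import numpy as np


# Hypothetical helper for illustration; not part of OTT.
def sq_dists(x: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances within one point cloud."""
    return ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)


# Hypothetical helper for illustration; not part of OTT.
def fused_gw_objective(cx, cy, p, m_xy=None, fused_penalty=1.0):
    """Unregularized (fused) GW cost of coupling p under the 'sqeucl' loss.

    quadratic = sum_{i,j,k,l} (cx[i,k] - cy[j,l])**2 * p[i,j] * p[k,l]
    linear    = sum_{i,j} m_xy[i,j] * p[i,j]   (only when m_xy is given)
    """
    a = p.sum(axis=1)  # row marginal of p
    b = p.sum(axis=0)  # column marginal of p
    # Expand (x - y)^2 = x^2 + y^2 - 2xy so no n x m x n x m tensor is built.
    quadratic = a @ cx**2 @ a + b @ cy**2 @ b - 2.0 * np.sum(p * (cx @ p @ cy.T))
    if m_xy is None:  # plain Gromov-Wasserstein: no linear term
        return float(quadratic)
    return float(quadratic + fused_penalty * np.sum(m_xy * p))


# Two small point clouds living in spaces of different dimension.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))
y = rng.normal(size=(5, 3))
cx, cy = sq_dists(x), sq_dists(y)

# Uniform weights (what a=None and b=None amount to) and their independent coupling.
a = np.full(4, 1.0 / 4)
b = np.full(5, 1.0 / 5)
p = np.outer(a, b)

# Hypothetical cross-space cost standing in for geom_xy:
# it compares x with the first two coordinates of y.
m_xy = ((x[:, None, :] - y[None, :, :2]) ** 2).sum(axis=-1)

print(fused_gw_objective(cx, cy, p))                           # plain GW term
print(fused_gw_objective(cx, cy, p, m_xy, fused_penalty=0.5))  # fused GW
```

The expansion of the square is the usual way to keep this term at matrix-product cost. A solver would minimize this quantity, plus regularization, over couplings with the prescribed marginals, rather than evaluate it at the single independent coupling used here.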
- Parameters:
  - geom_xx (Geometry) – Ground geometry of the first space.
  - geom_yy (Geometry) – Ground geometry of the second space.
  - geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for fused Gromov-Wasserstein [Titouan et al., 2019]. If None, the problem reduces to a plain Gromov-Wasserstein problem [Peyré et al., 2016].
  - fused_penalty (float) – Multiplier of the linear term in fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem.
  - a (Optional[Array]) – The first marginal. If None, it will be uniform.
  - b (Optional[Array]) – The second marginal. If None, it will be uniform.
  - tau_a (float) – If \(< 1\), defines how unbalanced the problem is on the first marginal.
  - tau_b (float) – If \(< 1\), defines how unbalanced the problem is on the second marginal.
  - loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – Gromov-Wasserstein loss function; see GWLoss for more information. If rank > 0, 'sqeucl' is always used.
  - gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the resolution of the linearization of the GW problem in the inner loop. Only used when rank = -1.
  - rank (int) – Rank constraint on the coupling used to minimize the quadratic OT problem [Scetbon et al., 2022]. If \(-1\), no rank constraint is used.
  - linear_solver_kwargs (Optional[Dict[str, Any]]) – Keyword arguments for Sinkhorn, if rank > 0.
  - kwargs (Any) – Keyword arguments for GromovWasserstein or LRGromovWasserstein, depending on the rank.
- Return type:
Union[GWOutput, LRGWOutput]
- Returns:
The Gromov-Wasserstein output.
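When rank > 0, the solver works with low-rank couplings. This sketch assumes the factorization used in the low-rank OT work of Scetbon et al. It is not a description of the fields of LRGWOutput or of how OTT computes the factors. Under that factorization, a coupling of rank at most r is written P = Q diag(1/g) R^T, where Q has marginals (a, g) and R has marginals (b, g). The NumPy sketch below builds such factors by Sinkhorn scaling of random positive matrices, using a hypothetical helper named `sinkhorn_project`. It then checks the rank and marginals of the resulting P.

```python
import numpy as np


# Hypothetical helper for illustration; not part of OTT.
def sinkhorn_project(k: np.ndarray, row: np.ndarray, col: np.ndarray, n_iter: int = 500) -> np.ndarray:
    """Rescale a positive matrix k so its row sums are `row` and column sums are `col`."""
    u = np.ones(k.shape[0])
    v = np.ones(k.shape[1])
    for _ in range(n_iter):
        u = row / (k @ v)    # match row sums
        v = col / (k.T @ u)  # match column sums (exact after this step)
    return u[:, None] * k * v[None, :]


rng = np.random.default_rng(0)
n, m, rank = 6, 5, 2
a = np.full(n, 1.0 / n)        # first marginal
b = np.full(m, 1.0 / m)        # second marginal
g = np.full(rank, 1.0 / rank)  # inner marginal shared by both factors

q = sinkhorn_project(rng.uniform(0.5, 1.5, size=(n, rank)), a, g)  # n x rank, marginals (a, g)
r = sinkhorn_project(rng.uniform(0.5, 1.5, size=(m, rank)), b, g)  # m x rank, marginals (b, g)

# Low-rank coupling P = Q diag(1/g) R^T: an n x m matrix of rank <= `rank`.
p = q @ np.diag(1.0 / g) @ r.T

print(np.linalg.matrix_rank(p))
print(np.abs(p.sum(axis=1) - a).max(), np.abs(p.sum(axis=0) - b).max())
```

In this parametrization, P is stored through thin factors of total size O((n + m) r) rather than as a dense n x m matrix. That saving is what makes a rank constraint attractive for large problems.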