ott.solvers.quadratic.solve
- ott.solvers.quadratic.solve(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, a=None, b=None, tau_a=1.0, tau_b=1.0, loss='sqeucl', gw_unbalanced_correction=True, rank=-1, linear_solver_kwargs=None, **kwargs)
Solve a quadratic regularized OT problem using a Gromov-Wasserstein solver.
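The sketch below makes the objective concrete. It is a minimal NumPy example that evaluates the unregularized (fused) Gromov-Wasserstein cost of a fixed coupling under the squared-Euclidean loss. It does not solve anything, and it is not OTT's implementation. The helper `fused_gw_objective` is hypothetical. The linear term is added as `fused_penalty * <C_xy, P>`, which follows the "purely quadratic + fused_penalty * linear" description given for `fused_penalty`. The uniform marginal is built by hand to mirror the documented default when `a` or `b` is `None`.

```python
import numpy as np


def fused_gw_objective(cx, cy, plan, cxy=None, fused_penalty=1.0):
    """Hypothetical helper: (fused) GW cost of a fixed plan, squared-Euclidean loss.

    cx: (n, n) costs in the first space, cy: (m, m) costs in the second space,
    plan: (n, m) coupling, cxy: optional (n, m) cost for the linear (fused) term.
    Brute force over an (n, m, n, m) tensor, so only suitable for tiny inputs.
    """
    # diff[i, j, k, l] = cx[i, k] - cy[j, l]
    diff = cx[:, None, :, None] - cy[None, :, None, :]
    # Quadratic term: sum_{i,j,k,l} (cx[i,k] - cy[j,l])**2 * plan[i,j] * plan[k,l]
    quadratic = np.einsum("ijkl,ij,kl->", diff**2, plan, plan)
    if cxy is None:
        return quadratic  # plain Gromov-Wasserstein
    # Fused GW: purely quadratic + fused_penalty * linear term <cxy, plan>
    return quadratic + fused_penalty * np.sum(cxy * plan)


rng = np.random.default_rng(0)
x = rng.normal(size=(5, 2))
cx = np.sum((x[:, None, :] - x[None, :, :]) ** 2, axis=-1)  # squared distances
n = x.shape[0]
a = np.full(n, 1.0 / n)  # uniform marginal, mirroring the default for a=None

identity_plan = np.diag(a)  # sends every point to itself
independent_plan = np.outer(a, a)  # product coupling, ignores geometry

print(fused_gw_objective(cx, cx, identity_plan))  # 0.0
print(fused_gw_objective(cx, cx, independent_plan) > 0)  # True
```

Matching a space to itself with the identity plan gives zero quadratic cost. The product coupling ignores structure and pays a positive price. The actual solver minimizes an entropically regularized version of this objective over couplings. It does not evaluate a given plan like this sketch does.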
- Parameters:
  - geom_xx (Geometry) – Ground geometry of the first space.
  - geom_yy (Geometry) – Ground geometry of the second space.
  - geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for fused Gromov-Wasserstein [Titouan et al., 2019]. If None, the problem reduces to a plain Gromov-Wasserstein problem [Peyré et al., 2016].
  - fused_penalty (float) – Multiplier of the linear term in fused Gromov-Wasserstein, i.e., problem = purely quadratic + fused_penalty * linear problem.
  - a (Optional[Array]) – The first marginal. If None, it will be uniform.
  - b (Optional[Array]) – The second marginal. If None, it will be uniform.
  - tau_a (float) – If \(< 1\), defines how unbalanced the problem is on the first marginal.
  - tau_b (float) – If \(< 1\), defines how unbalanced the problem is on the second marginal.
  - loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – Gromov-Wasserstein loss function; see GWLoss for more information. If rank > 0, 'sqeucl' is always used.
  - gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the resolution of the linearization of the GW problem in the inner loop. Only used when rank = -1.
  - rank (int) – Rank constraint on the coupling used to minimize the quadratic OT problem [Scetbon et al., 2022]. If \(-1\), no rank constraint is used.
  - linear_solver_kwargs (Optional[Dict[str, Any]]) – Keyword arguments for Sinkhorn, if rank > 0.
  - kwargs (Any) – Keyword arguments for GromovWasserstein or LRGromovWasserstein, depending on the rank.
- Return type:
- Returns:
The Gromov-Wasserstein output.
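The phrase "how unbalanced" for tau_a and tau_b can be read through the strength of a KL penalty on each marginal. The sketch below assumes the parameterization tau = rho / (rho + epsilon), where epsilon is the entropic regularization. This is the convention OTT's Sinkhorn documentation describes for its linear solver. We have not verified that every code path of this function uses it. The helper `tau_to_rho` is hypothetical and only illustrates the trend: tau = 1 means an infinite penalty (balanced), and smaller tau relaxes the marginal constraint.

```python
import math


def tau_to_rho(tau: float, epsilon: float) -> float:
    """Hypothetical helper: KL marginal-penalty weight implied by tau.

    Assumes the parameterization tau = rho / (rho + epsilon),
    i.e. rho = epsilon * tau / (1 - tau).
    """
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must lie in [0, 1]")
    if tau == 1.0:
        return math.inf  # infinite penalty: the marginal is matched exactly (balanced)
    return epsilon * tau / (1.0 - tau)


for tau in (1.0, 0.99, 0.9, 0.5):
    print(tau, tau_to_rho(tau, epsilon=1e-2))
```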
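The two built-in choices for the loss parameter, 'sqeucl' and 'kl', both fit the decomposition L(a, b) = f1(a) + f2(b) - h1(a) h2(b) from [Peyré et al., 2016]. GWLoss presumably describes a loss in these terms. The point of the decomposition is that the tensor-matrix product L(C_x, C_y) ⊗ T reduces to ordinary matrix products, so the 4-D tensor is never formed. The NumPy sketch below checks this for both losses against a brute-force reference. The dictionary keys and helper names are illustrative, not OTT's API.

```python
import numpy as np

# L(a, b) = f1(a) + f2(b) - h1(a) * h2(b), as in [Peyré et al., 2016].
SQEUCL = {  # (a - b)**2 = a**2 + b**2 - a * (2 b)
    "f1": lambda a: a**2,
    "f2": lambda b: b**2,
    "h1": lambda a: a,
    "h2": lambda b: 2.0 * b,
}
KL = {  # a log(a / b) - a + b = (a log a - a) + b - a * log b
    "f1": lambda a: a * np.log(a) - a,
    "f2": lambda b: b,
    "h1": lambda a: a,
    "h2": lambda b: np.log(b),
}


def direct_sqeucl(a, b):
    return (a - b) ** 2


def direct_kl(a, b):
    return a * np.log(a / b) - a + b


def loss_tensor_fast(loss, cx, cy, t):
    """Hypothetical helper: (L(cx, cy) ⊗ t)[i, j] using matrix products only."""
    p, q = t.sum(axis=1), t.sum(axis=0)  # marginals of t
    return (
        (loss["f1"](cx) @ p)[:, None]
        + (loss["f2"](cy) @ q)[None, :]
        - loss["h1"](cx) @ t @ loss["h2"](cy).T
    )


def loss_tensor_brute(direct, cx, cy, t):
    """Reference: sum_{k,l} L(cx[i,k], cy[j,l]) * t[k,l] through a 4-D tensor."""
    big = direct(cx[:, None, :, None], cy[None, :, None, :])
    return np.einsum("ijkl,kl->ij", big, t)


rng = np.random.default_rng(0)
cx = rng.uniform(0.5, 2.0, size=(4, 4))  # positive entries so the KL loss is defined
cy = rng.uniform(0.5, 2.0, size=(3, 3))
t = rng.uniform(size=(4, 3))
t /= t.sum()

for loss, direct in [(SQEUCL, direct_sqeucl), (KL, direct_kl)]:
    print(np.allclose(loss_tensor_fast(loss, cx, cy, t),
                      loss_tensor_brute(direct, cx, cy, t)))  # True
```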
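For rank > 0, the low-rank approach of [Scetbon et al., 2022] restricts the coupling to a factored form P = Q diag(1/g) R^T. Here Q has marginals (a, g), R has marginals (b, g), and g is an inner marginal of length rank. The NumPy sketch below only builds such a coupling from random positive factors, using plain Sinkhorn scaling. It checks that P keeps the marginals a and b and has rank at most `rank`. It is not how LRGromovWasserstein optimizes the factors. The helpers are hypothetical, and the uniform g is an arbitrary choice for the example.

```python
import numpy as np


def scale_to_marginals(k, row, col, n_iter=1000):
    """Sinkhorn scaling: rescale positive k so its row/column sums match row/col."""
    u, v = np.ones_like(row), np.ones_like(col)
    for _ in range(n_iter):
        u = row / (k @ v)
        v = col / (k.T @ u)
    return u[:, None] * k * v[None, :]


def low_rank_coupling(a, b, rank, rng):
    """Hypothetical helper: P = Q diag(1/g) R^T with Q in Pi(a, g), R in Pi(b, g)."""
    g = np.full(rank, 1.0 / rank)  # inner marginal; uniform is an arbitrary choice here
    q = scale_to_marginals(rng.uniform(0.1, 1.0, size=(a.size, rank)), a, g)
    r = scale_to_marginals(rng.uniform(0.1, 1.0, size=(b.size, rank)), b, g)
    return q @ np.diag(1.0 / g) @ r.T


rng = np.random.default_rng(0)
a = np.full(6, 1.0 / 6)
b = np.full(5, 1.0 / 5)
p = low_rank_coupling(a, b, rank=2, rng=rng)
print(np.linalg.matrix_rank(p))  # never exceeds 2
print(np.allclose(p.sum(axis=1), a), np.allclose(p.sum(axis=0), b))
```

The factored form is what makes the low-rank solver cheaper. Q and R have only (n + m) * rank entries, in place of the n * m entries of a full coupling.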
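Several parameter descriptions depend on rank. With rank > 0, kwargs go to LRGromovWasserstein and 'sqeucl' is always used. With rank = -1, kwargs go to GromovWasserstein and gw_unbalanced_correction applies. The sketch below only restates these documented rules as a pure-Python function. The function `describe_dispatch` is hypothetical and does not call OTT. Treating every rank <= 0 as "no rank constraint" is an assumption, since the text only documents -1.

```python
from typing import Any, Dict


def describe_dispatch(
    rank: int = -1,
    loss: Any = "sqeucl",
    gw_unbalanced_correction: bool = True,
) -> Dict[str, Any]:
    """Hypothetical summary of the documented rules; not OTT's code.

    Treats any rank <= 0 as "no rank constraint"; the text only documents -1.
    """
    if rank > 0:
        return {
            "solver": "LRGromovWasserstein",  # receives **kwargs
            "loss": "sqeucl",  # always used when rank > 0
            "gw_unbalanced_correction": None,  # only used when rank = -1
        }
    return {
        "solver": "GromovWasserstein",  # receives **kwargs
        "loss": loss,
        "gw_unbalanced_correction": gw_unbalanced_correction,
    }


print(describe_dispatch(rank=5, loss="kl")["loss"])  # sqeucl
print(describe_dispatch(loss="kl")["solver"])  # GromovWasserstein
```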