ott.solvers.quadratic.gromov_wasserstein.solve
- ott.solvers.quadratic.gromov_wasserstein.solve(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01, **kwargs)
Solve a quadratic regularized OT problem.
The quadratic loss of a single OT matrix is assumed to have the form given in [Peyré et al., 2016], eq. 4.
The two geometries below parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) in that equation is assumed to match the form given in eq. 5, which in our notation reads:
\[L(x, y) = \mathrm{lin1}(x) + \mathrm{lin2}(y) - \mathrm{quad1}(x)\,\mathrm{quad2}(y)\]
- Parameters:
geom_xx (Geometry) – Ground geometry of the first space.
geom_yy (Geometry) – Ground geometry of the second space.
geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.
fused_penalty (float) – Multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if geom_xy is not specified.
scale_cost (Union[bool, float, str, None]) – Option to rescale the cost matrices:
a (Optional[Array]) – Array of probability weights of the samples from geom_xx. If None, uniform weights are used.
b (Optional[Array]) – Array of probability weights of the samples from geom_yy. If None, uniform weights are used.
loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – A 2-tuple of 2-tuples of callables: the first tuple is the linear part of the loss (lin1, lin2), the second is the quadratic part (quad1, quad2). By default, the loss is given by the four functions that decompose the squared Euclidean loss, a structure that subsequent computations take advantage of. Alternatively, the KL loss can be specified, and it is handled just as efficiently.
tau_a (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the first marginal.
tau_b (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the second marginal.
gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the inner Sinkhorn loop.
ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Used when geometries are not PointCloud with 'sqeucl' cost function. If -1, the geometries will not be converted to low-rank. If tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If int, the rank is shared across all geometries.
tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when geometries are not PointCloud with 'sqeucl' cost. If float, it is shared across all geometries.
kwargs (Any) – Keyword arguments for GromovWasserstein.
- Return type:
- Returns:
Gromov-Wasserstein output.
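The decomposition \(L(x, y) = \mathrm{lin1}(x) + \mathrm{lin2}(y) - \mathrm{quad1}(x)\,\mathrm{quad2}(y)\) is what keeps the quadratic term tractable. The sketch below, in plain NumPy, is a minimal illustration under stated assumptions, not the OTT implementation: it treats \(C\), \(\bar{C}\) as dense matrices and \(T\) as an \(n \times m\) coupling, checks that the squared Euclidean and KL losses split into the four functions, and shows that \((L \otimes T)_{ij} = \sum_{k,l} L(C_{ik}, \bar{C}_{jl})\, T_{kl}\), which appears in the objective of [Peyré et al., 2016], can be computed with matrix products instead of an \(n^2 m^2\) tensor. All helper names (sqeucl_parts, kl_parts, tensor_product_factored, fused_gw_energy, ...) are hypothetical and not part of the OTT API, and the split quad1(x) = x, quad2(y) = 2y is one valid choice, not necessarily the one OTT uses internally. The fused energy follows the description of fused_penalty above (purely quadratic + fused_penalty * linear), without entropic regularization.

```python
import numpy as np

# Hypothetical helpers (not the OTT API): each loss is written as
# ((lin1, lin2), (quad1, quad2)) so that
# L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y).


def sqeucl_parts():
    # (x - y)^2 = x^2 + y^2 - x * (2y)
    return (lambda x: x**2, lambda y: y**2), (lambda x: x, lambda y: 2.0 * y)


def kl_parts():
    # x log(x / y) - x + y = (x log x - x) + y - x * log(y); requires x, y > 0
    return (lambda x: x * np.log(x) - x, lambda y: y), (lambda x: x, lambda y: np.log(y))


def loss_from_parts(parts, x, y):
    """Evaluate L(x, y) elementwise from its four component functions."""
    (lin1, lin2), (quad1, quad2) = parts
    return lin1(x) + lin2(y) - quad1(x) * quad2(y)


def tensor_product_naive(parts, C, Cbar, T):
    """(L ⊗ T)_ij = sum_kl L(C_ik, Cbar_jl) T_kl via an explicit n x m x n x m tensor."""
    L4 = loss_from_parts(parts, C[:, None, :, None], Cbar[None, :, None, :])
    return np.einsum("ijkl,kl->ij", L4, T)


def tensor_product_factored(parts, C, Cbar, T):
    """Same quantity using only matrix-vector and matrix-matrix products."""
    (lin1, lin2), (quad1, quad2) = parts
    p, q = T.sum(axis=1), T.sum(axis=0)  # marginals of the coupling T
    const = (lin1(C) @ p)[:, None] + (lin2(Cbar) @ q)[None, :]
    return const - quad1(C) @ T @ quad2(Cbar).T


def fused_gw_energy(parts, C, Cbar, T, M=None, fused_penalty=1.0):
    """Quadratic energy sum_ij (L ⊗ T)_ij T_ij, plus fused_penalty * <M, T> if M is given."""
    energy = np.sum(tensor_product_factored(parts, C, Cbar, T) * T)
    if M is not None:
        energy += fused_penalty * np.sum(M * T)
    return float(energy)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    C = rng.random((4, 4)) + 0.1  # strictly positive so the KL loss is defined
    Cbar = rng.random((3, 3)) + 0.1
    T = rng.random((4, 3))
    T /= T.sum()  # a normalized coupling
    for parts in (sqeucl_parts(), kl_parts()):
        naive = tensor_product_naive(parts, C, Cbar, T)
        fast = tensor_product_factored(parts, C, Cbar, T)
        print(np.allclose(naive, fast))
    M = rng.random((4, 3))  # a linear cost, playing the role of geom_xy
    print(fused_gw_energy(sqeucl_parts(), C, Cbar, T, M=M, fused_penalty=0.5))
```

With the factored form, the cost is dominated by the two matrix products quad1(C) @ T and (...) @ quad2(Cbar).T, i.e. \(O(n^2 m + n m^2)\) operations rather than \(O(n^2 m^2)\) for the explicit tensor.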