ott.problems.quadratic.quadratic_problem.QuadraticProblem
class ott.problems.quadratic.quadratic_problem.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01)
Quadratic OT problem.
The quadratic loss of a single OT matrix is assumed to have the form given in [Peyré et al., 2016], eq. 4.
The two geometries below parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) in that equation is assumed to match the form given in eq. 5, which in our notation reads:
\[L(x, y) = \mathrm{lin1}(x) + \mathrm{lin2}(y) - \mathrm{quad1}(x) \cdot \mathrm{quad2}(y)\]
Parameters:
geom_xx (Geometry) – Ground geometry of the first space.
geom_yy (Geometry) – Ground geometry of the second space.
geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.
fused_penalty (float) – Multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if geom_xy is not specified.
scale_cost (Union[bool, float, str, None]) – Option to rescale the cost matrices.
a (Optional[Array]) – Array of probability weights of the samples from geom_xx. If None, the weights are uniform.
b (Optional[Array]) – Array of probability weights of the samples from geom_yy. If None, the weights are uniform.
loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – A 2-tuple of 2-tuples of Callable. The first tuple is the linear part of the loss (lin1, lin2). The second one is the quadratic part (quad1, quad2). By default, the loss is set to the 4 functions representing the squared Euclidean loss, and this property is taken advantage of in subsequent computations. Alternatively, the KL loss can be specified, in a no less optimized way.
tau_a (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the first marginal.
tau_b (Optional[float]) – If < 1.0, defines how unbalanced the problem is on the second marginal.
gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the inner Sinkhorn loop.
ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices, see to_LRCGeometry(). Used when geometries are not PointCloud with 'sqeucl' cost function. If -1, the geometries will not be converted to low-rank. If tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If int, the rank is shared across all geometries.
tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when geometries are not PointCloud with 'sqeucl' cost. If float, it is shared across all geometries.
Methods
cost_unbalanced_correction(transport_matrix, ...) – Calculate cost term from the quadratic divergence when unbalanced.
Initialize the transport mass.
marginal_dependent_cost(marginal_1, ...[, ...]) – Initialize cost term that depends on the marginals of the transport.
to_low_rank([rng]) – Convert geometries to low-rank.
update_linearization(transport[, epsilon, ...]) – Update linearization of GW problem by updating cost matrix.
update_lr_geom(lr_sink[, remove_scale]) – Recompute (possibly LRC) linearization using LR Sinkhorn output.
update_lr_linearization(lr_sink, *[, ...]) – Update a Quad problem linearization using a LR Sinkhorn.
Attributes
First marginal.
Second marginal.
Geometry of the first space.
Geometry of the joint space.
Geometry of the second space.
Whether the problem is balanced.
Whether the problem is fused.
Whether all geometries are low-rank.
Linear part of the Gromov-Wasserstein loss.
Quadratic part of the Gromov-Wasserstein loss.
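The linear and quadratic parts of the loss correspond to the decomposition L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y) given at the top of this page. The sketch below, in plain Python, checks that decomposition for the squared Euclidean and KL losses and evaluates the resulting energy by brute force on a tiny example. The names SQEUCL, KL, decomposed_loss and fused_gw_energy are hypothetical helpers introduced here for illustration; they are not part of the library, and the quadruple sum is only meant to make the definition concrete, not to reflect how the library computes it.

```python
import math

# Each loss is a 2-tuple of 2-tuples: ((lin1, lin2), (quad1, quad2)).
# Squared Euclidean: (x - y)^2 = x^2 + y^2 - x * (2 * y)
SQEUCL = (
    (lambda x: x ** 2, lambda y: y ** 2),
    (lambda x: x, lambda y: 2.0 * y),
)
# KL: x * log(x / y) - x + y = (x * log(x) - x) + y - x * log(y)
KL = (
    (lambda x: x * math.log(x) - x, lambda y: y),
    (lambda x: x, lambda y: math.log(y)),
)


def decomposed_loss(loss, x, y):
    """Evaluate L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)."""
    (lin1, lin2), (quad1, quad2) = loss
    return lin1(x) + lin2(y) - quad1(x) * quad2(y)


def fused_gw_energy(loss, C, C_bar, T, M=None, fused_penalty=1.0):
    """Brute-force energy: quadratic term + fused_penalty * linear term."""
    n, m = len(C), len(C_bar)
    quad = sum(
        decomposed_loss(loss, C[i][k], C_bar[j][l]) * T[i][j] * T[k][l]
        for i in range(n) for j in range(m)
        for k in range(n) for l in range(m)
    )
    if M is None:  # no linear term: plain Gromov-Wasserstein
        return quad
    lin = sum(M[i][j] * T[i][j] for i in range(n) for j in range(m))
    return quad + fused_penalty * lin


# The decompositions agree with the direct formulas.
sq_value = decomposed_loss(SQEUCL, 3.0, 1.0)   # (3 - 1)^2
kl_value = decomposed_loss(KL, 2.0, 0.5)       # 2*log(2/0.5) - 2 + 0.5

# Tiny example: two 2-point spaces and the diagonal coupling.
C = [[0.0, 1.0], [1.0, 0.0]]
C_bar = [[0.0, 2.0], [2.0, 0.0]]
T = [[0.5, 0.0], [0.0, 0.5]]
M = [[1.0, 0.0], [0.0, 1.0]]

gw_energy = fused_gw_energy(SQEUCL, C, C_bar, T)
fgw_energy = fused_gw_energy(SQEUCL, C, C_bar, T, M=M, fused_penalty=2.0)
```

In this example only the two off-diagonal entries of C and \(\bar{C}\) differ, each contributing (1 - 2)^2 * 0.25 to the quadratic term, and the linear term contributes fused_penalty * 1.0.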