ott.problems.quadratic.quadratic_problem.QuadraticProblem
class ott.problems.quadratic.quadratic_problem.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=None, a=None, b=None, loss='sqeucl', tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True, ranks=-1, tolerances=0.01)
Quadratic OT problem.
The quadratic loss of a single OT matrix \(T\) is assumed to have the form given in [Peyré et al., 2016], eq. 4: \[\mathcal{E}(T) = \sum_{i,j,k,l} L(C_{ik}, \bar{C}_{jl})\, T_{ij} T_{kl}.\]
The geometries geom_xx and geom_yy below parameterize the matrices \(C\) and \(\bar{C}\) in that equation. The function \(L\) (of two real values) in that equation is assumed to have the form given in eq. 5, which in our notation reads:
\[L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)\]
Parameters:
- geom_xx (Geometry) – Ground geometry of the first space.
- geom_yy (Geometry) – Ground geometry of the second space.
- geom_xy (Optional[Geometry]) – Geometry defining the linear penalty term for fused Gromov-Wasserstein [Titouan et al., 2019]. If None, the problem reduces to a plain Gromov-Wasserstein problem [Peyré et al., 2016].
- fused_penalty (float) – Multiplier of the linear term in fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem.
- scale_cost (Union[float, str, None]) – How to rescale the cost matrices. If a str, use the specific options available in Geometry or PointCloud. If None, keep the original scaling.
- a (Optional[Array]) – The first marginal. If None, it will be uniform.
- b (Optional[Array]) – The second marginal. If None, it will be uniform.
- loss (Union[Literal['sqeucl', 'kl'], GWLoss]) – Gromov-Wasserstein loss function; see GWLoss for more information.
- tau_a (float) – If \(< 1.0\), relaxes the constraint on the first marginal; smaller values make the problem more unbalanced on that marginal.
- tau_b (float) – If \(< 1.0\), relaxes the constraint on the second marginal; smaller values make the problem more unbalanced on that marginal.
- gw_unbalanced_correction (bool) – Whether the unbalanced version of [Sejourne et al., 2021] is used. Otherwise, tau_a and tau_b only affect the inner Sinkhorn loop.
- ranks (Union[int, Tuple[int, ...]]) – Ranks of the cost matrices; see to_LRCGeometry(). Used when the geometries are not PointCloud with the 'sqeucl' cost function. If -1, the geometries are not converted to low-rank. If a tuple, it specifies the ranks of geom_xx, geom_yy and geom_xy, respectively. If an int, the rank is shared across all geometries.
- tolerances (Union[float, Tuple[float, ...]]) – Tolerances used when converting geometries to low-rank. Used when the geometries are not PointCloud with the 'sqeucl' cost function. If a float, it is shared across all geometries.
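For the two string options of loss, [Peyré et al., 2016] list eq. 5 decompositions of the square loss \(L(x, y) = (x - y)^2\) and of the KL loss \(L(x, y) = x \log(x / y) - x + y\). The plain-Python sketch below assumes that 'sqeucl' and 'kl' refer to these two losses (GWLoss holds the definitive definitions) and checks numerically that each one equals \(f_1(x) + f_2(y) - h_1(x) h_2(y)\) with the functions listed in the comments. The dictionaries and helper names are hypothetical and not part of OTT.

```python
import math

# Square loss L(x, y) = (x - y)^2 (assumed meaning of 'sqeucl'):
# f1(x) = x^2, f2(y) = y^2, h1(x) = x, h2(y) = 2y.
SQUARE = dict(
    f1=lambda x: x ** 2,
    f2=lambda y: y ** 2,
    h1=lambda x: x,
    h2=lambda y: 2.0 * y,
)

# KL loss L(x, y) = x log(x / y) - x + y for x, y > 0 (assumed meaning of 'kl'):
# f1(x) = x log x - x, f2(y) = y, h1(x) = x, h2(y) = log y.
KL = dict(
    f1=lambda x: x * math.log(x) - x,
    f2=lambda y: y,
    h1=lambda x: x,
    h2=lambda y: math.log(y),
)


def decomposed(parts, x, y):
    """Evaluate f1(x) + f2(y) - h1(x) * h2(y) for one decomposition."""
    return parts["f1"](x) + parts["f2"](y) - parts["h1"](x) * parts["h2"](y)


def square_loss(x, y):
    return (x - y) ** 2


def kl_loss(x, y):
    return x * math.log(x / y) - x + y


# Both decompositions agree with the direct formulas on a few positive inputs.
for x, y in [(0.5, 2.0), (1.0, 1.0), (3.0, 0.25)]:
    assert math.isclose(decomposed(SQUARE, x, y), square_loss(x, y), abs_tol=1e-12)
    assert math.isclose(decomposed(KL, x, y), kl_loss(x, y), abs_tol=1e-12)
print("both losses match their eq. 5 decomposition")
```

Only the split into \(f_1, f_2, h_1, h_2\) matters to the solver; the separable terms and the product term are what make the quadratic loss cheap to evaluate.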
Methods
cost_unbalanced_correction(transport_matrix, ...) – Calculate the cost term from the quadratic divergence when the problem is unbalanced.
Initialize the transport mass.
marginal_dependent_cost(marginal_1, marginal_2) – Initialize the cost term that depends on the marginals of the transport.
to_low_rank([rng]) – Convert geometries to low-rank.
update_linearization(transport[, epsilon, ...]) – Update the linearization of the GW problem by updating its cost matrix.
update_lr_geom(lr_sink[, relative_epsilon]) – Recompute the (possibly LRC) linearization using the low-rank (LR) Sinkhorn output.
update_lr_linearization(lr_sink, *[, ...]) – Update the quadratic problem's linearization using a low-rank (LR) Sinkhorn output.
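update_linearization turns the quadratic objective into a cost matrix evaluated at the current transport. A natural reading of that step is the tensor-matrix product \((L(C, \bar{C}) \otimes T)_{ij} = \sum_{k,l} L(C_{ik}, \bar{C}_{jl}) T_{kl}\), which, for losses of the eq. 5 form, factorizes as in [Peyré et al., 2016] into a marginal-dependent term minus \(h_1(C)\, T\, h_2(\bar{C})^\top\), so only matrix products are needed. The NumPy sketch below is not OTT code (all names are hypothetical): it checks that factorization against a brute-force evaluation of eq. 4 and shows how a fused term adds fused_penalty times a cross-space cost. It makes no claim about the exact scaling or unbalanced corrections OTT applies.

```python
import numpy as np


def gw_energy_bruteforce(C, Cbar, T, loss):
    """Eq. 4: sum over i, j, k, l of L(C[i, k], Cbar[j, l]) * T[i, j] * T[k, l]."""
    # L4[i, k, j, l] = L(C[i, k], Cbar[j, l]), built by broadcasting.
    L4 = loss(C[:, :, None, None], Cbar[None, None, :, :])
    return np.einsum("ikjl,ij,kl->", L4, T, T)


def linearized_cost(C, Cbar, T, f1, f2, h1, h2):
    """(L(C, Cbar) x T)[i, j] = sum_{k, l} L(C[i, k], Cbar[j, l]) * T[k, l],
    computed with matrix products only, using the eq. 5 decomposition."""
    p = T.sum(axis=1)  # row sums of T (its first marginal)
    q = T.sum(axis=0)  # column sums of T (its second marginal)
    marginal_term = (f1(C) @ p)[:, None] + (f2(Cbar) @ q)[None, :]
    return marginal_term - h1(C) @ T @ h2(Cbar).T


rng = np.random.default_rng(0)
n, m = 4, 5
x = rng.normal(size=(n, 2))  # points in the first space
y = rng.normal(size=(m, 3))  # points in the second space (different dimension)
C = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
Cbar = ((y[:, None, :] - y[None, :, :]) ** 2).sum(-1)
T = rng.random((n, m))
T /= T.sum()

# Square loss: f1 = f2 = square, h1 = identity, h2 = 2 * identity.
lin = linearized_cost(C, Cbar, T, np.square, np.square, lambda c: c, lambda c: 2.0 * c)
energy = gw_energy_bruteforce(C, Cbar, T, lambda u, v: (u - v) ** 2)
assert np.isclose((lin * T).sum(), energy)

# Fused GW with a hypothetical cross-space cost M (the role of geom_xy):
# adding fused_penalty * M to the linear cost adds fused_penalty * <M, T>.
fused_penalty = 0.5
M = rng.random((n, m))
fused_cost = lin + fused_penalty * M
assert np.isclose((fused_cost * T).sum(), energy + fused_penalty * (M * T).sum())
```

The brute-force energy builds an \(n \times n \times m \times m\) array, while the factorized cost only multiplies \(n \times n\), \(n \times m\) and \(m \times m\) matrices; this gap is why the eq. 5 structure is required.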
Attributes
First marginal.
Second marginal.
Geometry of the first space.
Geometry of the joint space.
Geometry of the second space.
Whether the problem is balanced.
Whether the problem is fused.
Whether all geometries are low-rank.
Linear part of the Gromov-Wasserstein loss.
Quadratic part of the Gromov-Wasserstein loss.
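The ranks and tolerances parameters are documented as unused for PointCloud geometries with the 'sqeucl' cost. A likely reason is that such a cost matrix already has an exact low-rank factorization, from \(\|x - y\|^2 = \|x\|^2 + \|y\|^2 - 2\langle x, y\rangle\), of rank at most \(d + 2\) for points in \(\mathbb{R}^d\). Other geometries have to be approximated, and that approximation is where ranks and tolerances come in. The NumPy sketch below builds the exact factors and checks them against the dense matrix. The function sqeucl_factors is hypothetical and is not how OTT constructs its LRCGeometry.

```python
import numpy as np


def sqeucl_factors(x, y):
    """Exact factors A (n, d + 2) and B (m, d + 2) with A @ B.T == ||x_i - y_j||^2."""
    n, m = x.shape[0], y.shape[0]
    x_sq = (x ** 2).sum(axis=1, keepdims=True)  # (n, 1) squared norms of x
    y_sq = (y ** 2).sum(axis=1, keepdims=True)  # (m, 1) squared norms of y
    A = np.concatenate([x_sq, np.ones((n, 1)), -2.0 * x], axis=1)
    B = np.concatenate([np.ones((m, 1)), y_sq, y], axis=1)
    return A, B


rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))
y = rng.normal(size=(40, 3))
A, B = sqeucl_factors(x, y)
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # dense (50, 40) cost
assert np.allclose(A @ B.T, C)

# Applying the cost to a vector through the factors costs O((n + m)(d + 2))
# operations instead of O(nm) for the dense matrix.
v = rng.random(40)
assert np.allclose(A @ (B.T @ v), C @ v)
```

For geom_xx or geom_yy in a Gromov-Wasserstein problem, the same construction applies with y = x.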