ott.core.quad_problems.QuadraticProblem
- class ott.core.quad_problems.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss=None, tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True)
Definition of the quadratic regularized OT problem.
The quadratic loss of a single OT matrix is assumed to have the form given in Eq. 4 from
http://proceedings.mlr.press/v48/peyre16.pdf
The geometries geom_xx and geom_yy below parameterize the matrices C and \bar{C} in that equation. The function L (of two real values) in that equation is assumed to match the form given in Eq. 5, with our notation:
L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
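As an illustration of this form, both the squared Euclidean loss (x - y)^2 and a KL-type loss x log(x/y) - x + y can be split into lin1, lin2, quad1 and quad2. The plain-JAX sketch below checks this numerically, using tuples shaped like the loss parameter (a 2-tuple of 2-tuples). The particular splits are assumptions chosen for illustration and may differ from the functions ott builds internally (for instance in make_kl_loss).

```python
import jax.numpy as jnp

# Each loss is a 2-tuple of 2-tuples: ((lin1, lin2), (quad1, quad2)).
# Squared Euclidean loss: (x - y)^2 = x^2 + y^2 - x * (2y).
# This particular split is an illustrative assumption.
sq_loss = (
    (lambda x: x ** 2, lambda y: y ** 2),  # (lin1, lin2)
    (lambda x: x, lambda y: 2.0 * y),      # (quad1, quad2)
)

# KL-type loss for x, y > 0: x log(x / y) - x + y = (x log x - x) + y - x * log(y).
kl_loss = (
    (lambda x: x * jnp.log(x) - x, lambda y: y),  # (lin1, lin2)
    (lambda x: x, lambda y: jnp.log(y)),          # (quad1, quad2)
)


def eval_loss(loss, x, y):
    """Evaluate L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y) elementwise."""
    (lin1, lin2), (quad1, quad2) = loss
    return lin1(x) + lin2(y) - quad1(x) * quad2(y)


x = jnp.array([0.5, 1.0, 2.0])
y = jnp.array([1.5, 0.25, 2.0])
print(eval_loss(sq_loss, x, y))  # matches (x - y) ** 2
print(eval_loss(kl_loss, x, y))  # matches x * log(x / y) - x + y up to rounding
```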
- Parameters
geom_xx (Geometry) – the geometry.Geometry object defining the ground geometry / cost of the first space.
geom_yy (Geometry) – the geometry.Geometry object defining the ground geometry / cost of the second space.
geom_xy (Optional[Geometry]) – the geometry.Geometry object defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.
fused_penalty (float) – multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic + fused_penalty * linear problem. Ignored if geom_xy is not specified.
scale_cost (Union[bool, float, str, None]) – option to rescale the cost matrices:
if True, use the default for each geometry.
if False, keep the original scaling in geometries.
if str, use a specific method available in ott.geometry.geometry.Geometry or ott.geometry.pointcloud.PointCloud.
if None, do not scale the cost matrices.
a (Optional[ndarray]) – jnp.ndarray[n] representing the probability weights of the samples from geom_xx. If None, it will be uniform.
b (Optional[ndarray]) – jnp.ndarray[m] representing the probability weights of the samples from geom_yy. If None, it will be uniform.
loss (Optional[Tuple[Tuple[Callable[[ndarray], ndarray], Callable[[ndarray], ndarray]], Tuple[Callable[[ndarray], ndarray], Callable[[ndarray], ndarray]]]]) – a 2-tuple of 2-tuples of Callable. The first tuple is the linear part of the loss (lin1, lin2 in the class description above); the second is the quadratic part (quad1, quad2). If None, the loss is set to the four functions representing the squared Euclidean loss, and this structure is exploited in subsequent computations. See make_kl_loss for an alternative, equally optimized way of setting the loss.
tau_a (Optional[float]) – if lower than 1.0, defines how unbalanced the problem is on the first marginal.
tau_b (Optional[float]) – if lower than 1.0, defines how unbalanced the problem is on the second marginal.
gw_unbalanced_correction (Optional[bool]) – True (default) if the unbalanced version of Séjourné et al. (NeurIPS 2021) is used; False if tau_a and tau_b only affect the inner Sinkhorn loop.
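The decomposition is what makes the quadratic term cheap, as shown in the paper linked above: the tensor product [L(C, \bar{C}) ⊗ T]_ij = sum_{k,l} L(C_ik, \bar{C}_jl) T_kl equals lin1(C) p 1^T + 1 q^T lin2(\bar{C})^T - quad1(C) T quad2(\bar{C})^T, where p and q are the marginals of T, so the n x m x n x m tensor never has to be formed. The plain-JAX sketch below (not ott code) compares the brute-force and factored forms, then evaluates a fused objective read literally as "purely quadratic + fused_penalty * linear". The matrix M is a hypothetical stand-in for the cost matrix of geom_xy, and the exact weighting ott applies may differ.

```python
import jax
import jax.numpy as jnp


def tensor_product_bruteforce(C, Cbar, T, loss_fn):
    """[L(C, Cbar) (x) T]_ij = sum_kl L(C_ik, Cbar_jl) T_kl via an explicit 4-D tensor."""
    L4 = loss_fn(C[:, None, :, None], Cbar[None, :, None, :])  # L4[i, j, k, l], shape (n, m, n, m)
    return jnp.einsum("ijkl,kl->ij", L4, T)


def tensor_product_factored(C, Cbar, T, loss):
    """Same quantity using the ((lin1, lin2), (quad1, quad2)) decomposition."""
    (lin1, lin2), (quad1, quad2) = loss
    p, q = T.sum(axis=1), T.sum(axis=0)  # marginals of T
    marginal_term = (lin1(C) @ p)[:, None] + (lin2(Cbar) @ q)[None, :]
    return marginal_term - quad1(C) @ T @ quad2(Cbar).T


def fused_objective(C, Cbar, M, T, loss, fused_penalty=1.0):
    """Quadratic (GW) term <L(C, Cbar) (x) T, T> plus fused_penalty * <M, T>."""
    quadratic = jnp.sum(tensor_product_factored(C, Cbar, T, loss) * T)
    return quadratic + fused_penalty * jnp.sum(M * T)


# Squared Euclidean loss split as ((lin1, lin2), (quad1, quad2)); illustrative choice.
sq_loss = ((lambda u: u ** 2, lambda v: v ** 2), (lambda u: u, lambda v: 2.0 * v))
kx, ky, km = jax.random.split(jax.random.PRNGKey(0), 3)
x, y = jax.random.normal(kx, (5, 2)), jax.random.normal(ky, (4, 3))
C = jnp.sum((x[:, None] - x[None]) ** 2, axis=-1)     # (5, 5) costs within the first space
Cbar = jnp.sum((y[:, None] - y[None]) ** 2, axis=-1)  # (4, 4) costs within the second space
M = jax.random.uniform(km, (5, 4))                    # hypothetical cross-space cost for the fused term
T = jnp.outer(jnp.full(5, 0.2), jnp.full(4, 0.25))    # the naive coupling a b^T

dense = tensor_product_bruteforce(C, Cbar, T, lambda u, v: (u - v) ** 2)
fast = tensor_product_factored(C, Cbar, T, sq_loss)
print(jnp.max(jnp.abs(dense - fast)))  # close to 0 (floating-point error only)
print(fused_objective(C, Cbar, M, T, sq_loss, fused_penalty=0.5))
```

The factored form costs O(n^2 m + n m^2) instead of O(n^2 m^2) time and memory, which is why a loss given in this structure can be exploited in subsequent computations.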
Methods
cost_unbalanced_correction(transport_matrix, ...) – Calculate the cost term from the quadratic divergence when unbalanced.
init_linearization([epsilon]) – Initialise a linear problem locally around the naive initializer a b^T.
init_lr_linearization([rank]) – Linearize a Quad problem with a predefined initializer.
Initialise the transport matrix.
Initialise the transport mass.
marginal_dependent_cost(marginal_1, marginal_2) – Initialise the cost term that depends on the marginals of the transport.
update_linearization(transport[, epsilon, ...]) – Update the linearization of the GW problem by updating the cost matrix.
update_lr_geom(lr_sink) – Recompute the (possibly LRC) linearization using the LR Sinkhorn output.
update_lr_linearization(lr_sink) – Update a Quad problem linearization using an LR Sinkhorn.
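Conceptually, a linearize-then-solve loop of this kind follows the entropic Gromov-Wasserstein scheme of the paper linked above: start from the naive coupling a b^T, compute the linearized cost L(C, \bar{C}) ⊗ T, solve the resulting linear entropic OT problem with Sinkhorn, and repeat. The sketch below assumes the squared Euclidean loss, balanced marginals (no tau_a, tau_b or gw_unbalanced_correction) and a plain log-domain Sinkhorn written from scratch; it is not ott's implementation and ignores the low-rank (lr_*) variants. The helper names are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def tensor_product(C, Cbar, T, loss):
    """[L(C, Cbar) (x) T]_ij = sum_kl L(C_ik, Cbar_jl) T_kl, in factored form."""
    (lin1, lin2), (quad1, quad2) = loss
    p, q = T.sum(axis=1), T.sum(axis=0)  # marginals of the current coupling
    marginal_term = (lin1(C) @ p)[:, None] + (lin2(Cbar) @ q)[None, :]
    return marginal_term - quad1(C) @ T @ quad2(Cbar).T


@functools.partial(jax.jit, static_argnames="num_iters")
def sinkhorn(cost, a, b, epsilon, num_iters=500):
    """Log-domain Sinkhorn; returns an entropic coupling with marginals close to a and b."""
    def body(_, fg):
        f, g = fg
        f = epsilon * (jnp.log(a) - logsumexp((g[None, :] - cost) / epsilon, axis=1))
        g = epsilon * (jnp.log(b) - logsumexp((f[:, None] - cost) / epsilon, axis=0))
        return f, g

    f, g = jax.lax.fori_loop(0, num_iters, body, (jnp.zeros_like(a), jnp.zeros_like(b)))
    return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


def entropic_gw(C, Cbar, a, b, loss, epsilon=0.2, outer_iters=10):
    T = jnp.outer(a, b)  # naive initializer a b^T
    for _ in range(outer_iters):
        cost = tensor_product(C, Cbar, T, loss)  # linearize the quadratic loss around T
        T = sinkhorn(cost, a, b, epsilon)        # solve the linear entropic OT problem
    return T


sq_loss = ((lambda u: u ** 2, lambda v: v ** 2), (lambda u: u, lambda v: 2.0 * v))
kx, ky = jax.random.split(jax.random.PRNGKey(0))
x, y = jax.random.normal(kx, (6, 2)), jax.random.normal(ky, (5, 3))
C = jnp.sum((x[:, None] - x[None]) ** 2, axis=-1)
Cbar = jnp.sum((y[:, None] - y[None]) ** 2, axis=-1)
C, Cbar = C / C.max(), Cbar / Cbar.max()  # rescale both costs to [0, 1]
a, b = jnp.full(6, 1 / 6), jnp.full(5, 1 / 5)
T = entropic_gw(C, Cbar, a, b, sq_loss)
print(T.sum(axis=1), T.sum(axis=0))  # approximately a and b
```

Rescaling C and \bar{C} to [0, 1] keeps the linearized cost bounded so that the Sinkhorn iterations converge quickly for this epsilon, loosely in the spirit of the scale_cost option; jitting sinkhorn means it is compiled once and reused across outer iterations.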
Attributes