ott.core.quad_problems.QuadraticProblem

class ott.core.quad_problems.QuadraticProblem(geom_xx, geom_yy, geom_xy=None, fused_penalty=1.0, scale_cost=False, a=None, b=None, loss=None, tau_a=1.0, tau_b=1.0, gw_unbalanced_correction=True)

Definition of the quadratic regularized OT problem.

The quadratic loss of a single transport matrix is assumed to have the form given in Eq. 4 of Peyré et al. (2016):

http://proceedings.mlr.press/v48/peyre16.pdf

The two geometries below parameterize the matrices C and C̄ in that equation. The function L of two real arguments in that equation is assumed to have the form given in Eq. 5, which in our notation reads:

L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y)
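With this decomposition, applying the loss tensor to a transport matrix T, (L ⊗ T)_ij = Σ_kl L(C_ik, C̄_jl) T_kl, needs only matrix products and the marginals of T. A brute-force loop over all index quadruples is not required. The sketch below checks this identity numerically for the squared Euclidean loss. It is a NumPy illustration, not ott's JAX implementation. The helper names `tensor_naive` and `tensor_fast` are hypothetical.

```python
import numpy as np

# Squared Euclidean loss (x - y)^2 = x^2 + y^2 - (2x) * y, split as in
# L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y).
lin1 = lambda x: x ** 2
lin2 = lambda y: y ** 2
quad1 = lambda x: 2.0 * x
quad2 = lambda y: y


def tensor_naive(C, Cbar, T):
    """(L ⊗ T)_ij = sum_{k,l} (C_ik - Cbar_jl)^2 T_kl, by brute force."""
    n, m = T.shape
    out = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            loss_kl = (C[i][:, None] - Cbar[j][None, :]) ** 2  # entry [k, l]
            out[i, j] = np.sum(loss_kl * T)
    return out


def tensor_fast(C, Cbar, T):
    """Same quantity using only matrix products and the marginals of T."""
    p, q = T.sum(axis=1), T.sum(axis=0)  # row and column marginals of T
    marginal_term = (lin1(C) @ p)[:, None] + (lin2(Cbar) @ q)[None, :]
    return marginal_term - quad1(C) @ T @ quad2(Cbar).T


rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(4, 3))
C = np.sum((x[:, None] - x[None]) ** 2, axis=-1)     # 5 x 5 cost, first space
Cbar = np.sum((y[:, None] - y[None]) ** 2, axis=-1)  # 4 x 4 cost, second space
T = rng.random((5, 4))
T /= T.sum()  # any nonnegative matrix works for the identity

energy = np.sum(tensor_fast(C, Cbar, T) * T)  # Eq. 4 objective for this T
```

The fast form costs a few matrix products instead of O(n²m²) work. This is presumably why the class asks for the loss in this split form.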

Parameters
  • geom_xx (Geometry) – the geometry.Geometry object defining the ground geometry / cost of the first space.

  • geom_yy (Geometry) – the geometry.Geometry object defining the ground geometry / cost of the second space.

  • geom_xy (Optional[Geometry]) – the geometry.Geometry object defining the linear penalty term for Fused Gromov-Wasserstein. If None, the problem reduces to a plain Gromov-Wasserstein problem.

  • fused_penalty (float) – multiplier of the linear term in Fused Gromov-Wasserstein, i.e. problem = purely quadratic problem + fused_penalty * linear problem. Ignored if geom_xy is not specified.

  • scale_cost (Union[bool, float, str, None]) –

    option to rescale the cost matrices:

    • if True, use the default for each geometry.

    • if False, keep the original scaling in geometries.

    • if str, use a specific method available in ott.geometry.geometry.Geometry or ott.geometry.pointcloud.PointCloud.

    • if None, do not scale the cost matrices.

  • a (Optional[ndarray]) – jnp.ndarray[n] representing the probability weights of the samples from geom_xx. If None, it will be uniform.

  • b (Optional[ndarray]) – jnp.ndarray[m] representing the probability weights of the samples from geom_yy. If None, it will be uniform.

  • loss (Optional[Tuple[Tuple[Callable[[ndarray], ndarray], Callable[[ndarray], ndarray]], Tuple[Callable[[ndarray], ndarray], Callable[[ndarray], ndarray]]]]) – a 2-tuple of 2-tuples of Callable. The first tuple is the linear part of the loss (lin1, lin2 in the class docstring). The second tuple is the quadratic part (quad1, quad2). If None, the loss is set to the four functions representing the squared Euclidean loss, and subsequent computations take advantage of this structure. See make_kl_loss for an alternative, equally optimized way of setting the loss.

  • tau_a (Optional[float]) – if lower than 1.0, defines how unbalanced the problem is on the first marginal.

  • tau_b (Optional[float]) – if lower than 1.0, defines how unbalanced the problem is on the second marginal.

  • gw_unbalanced_correction (Optional[bool]) – True (default) if the unbalanced version of Séjourné et al. (NeurIPS 2021) is used, False if tau_a and tau_b only affect the inner Sinkhorn loop.
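The fused_penalty parameter combines a quadratic (Gromov-Wasserstein) term with a linear term built from the geom_xy cost. Here is a minimal NumPy sketch of that objective for a fixed transport matrix. It assumes the squared Euclidean loss and writes M for the geom_xy cost matrix. The names `gw_energy`, `fgw_objective` and `M` are hypothetical, and this is not ott's code.

```python
import numpy as np


def gw_energy(C, Cbar, T):
    """Quadratic term: sum_{i,j,k,l} (C_ik - Cbar_jl)^2 T_ij T_kl."""
    diff = C[:, None, :, None] - Cbar[None, :, None, :]  # axes (i, j, k, l)
    return np.einsum("ijkl,ij,kl->", diff ** 2, T, T)


def fgw_objective(C, Cbar, M, T, fused_penalty=1.0):
    """Purely quadratic term + fused_penalty * linear term <M, T>."""
    return gw_energy(C, Cbar, T) + fused_penalty * np.sum(M * T)


rng = np.random.default_rng(0)
n = 4
x = rng.normal(size=(n, 2))
C = np.sum((x[:, None] - x[None]) ** 2, axis=-1)
M = rng.random((n, n))  # stand-in for the cost matrix of geom_xy
T_id = np.eye(n) / n    # couples each point with itself
```

Matching a space with itself through T_id gives a zero quadratic term. The objective then grows linearly in fused_penalty with slope ⟨M, T⟩.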
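The loss parameter is a nested tuple ((lin1, lin2), (quad1, quad2)). The sketch below builds two such tuples and checks that they reproduce the intended scalar losses. The first is the squared Euclidean loss used by default. The second is a KL loss L(x, y) = x log(x / y) - x + y, split as in Peyré et al. (2016). Whether make_kl_loss uses exactly these four functions is an assumption. The code uses NumPy instead of jax.numpy, and `evaluate` is a hypothetical helper.

```python
import numpy as np

# Each loss is ((lin1, lin2), (quad1, quad2)) such that
# L(x, y) = lin1(x) + lin2(y) - quad1(x) * quad2(y).
sqeucl_loss = (
    (lambda x: x ** 2, lambda y: y ** 2),
    (lambda x: 2.0 * x, lambda y: y),
)

# KL loss x log(x / y) - x + y, valid for x, y > 0 (assumed split).
kl_loss = (
    (lambda x: x * np.log(x) - x, lambda y: y),
    (lambda x: x, lambda y: np.log(y)),
)


def evaluate(loss, x, y):
    """Rebuild the scalar loss from its four component functions."""
    (lin1, lin2), (quad1, quad2) = loss
    return lin1(x) + lin2(y) - quad1(x) * quad2(y)


x = np.array([0.5, 2.0, 1.0])
y = np.array([1.5, 0.3, 1.0])
```

Only the split matters for efficiency. Any loss that fits this four-function form gets the matrix-product shortcut.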
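In the entropic unbalanced OT literature, a KL penalty of weight ρ on a marginal is often encoded as τ = ρ / (ρ + ε). The Sinkhorn scaling update is then raised to the power τ, as in u = (a / K v)^τ. With τ = 1 (ρ = ∞), the marginal constraint is hard. The sketch below assumes tau_a and tau_b follow this convention and shows the conversion both ways. The function names are hypothetical and not part of ott.

```python
import math


def tau_from_rho(rho: float, epsilon: float) -> float:
    """Map a KL marginal penalty rho to tau = rho / (rho + epsilon)."""
    return 1.0 if math.isinf(rho) else rho / (rho + epsilon)


def rho_from_tau(tau: float, epsilon: float) -> float:
    """Inverse map: rho = epsilon * tau / (1 - tau); tau = 1 means rho = inf."""
    return math.inf if tau == 1.0 else epsilon * tau / (1.0 - tau)


def is_balanced(tau_a: float, tau_b: float) -> bool:
    """Both marginal constraints are hard exactly when tau_a == tau_b == 1."""
    return tau_a == 1.0 and tau_b == 1.0
```

Under this convention, a smaller tau means a weaker marginal penalty relative to the entropic regularization.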

Methods

cost_unbalanced_correction(transport_matrix, ...)

Calculate cost term from the quadratic divergence when unbalanced.

init_linearization([epsilon])

Initialise a linear problem locally around the naive initializer a b^T (the outer product of the marginals).

init_lr_linearization([rank])

Linearize a quadratic problem with a predefined initializer.

init_transport()

Initialise the transport matrix.

init_transport_mass()

Initialise the transport mass.

marginal_dependent_cost(marginal_1, marginal_2)

Initialise cost term that depends on the marginals of the transport.

update_linearization(transport[, epsilon, ...])

Update the linearization of the GW problem by updating its cost matrix.

update_lr_geom(lr_sink)

Recompute the (possibly low-rank cost, LRC) linearization using the output of low-rank (LR) Sinkhorn.

update_lr_linearization(lr_sink)

Update the linearization of a quadratic problem using the output of low-rank (LR) Sinkhorn.
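init_linearization starts from the naive initializer ab', the outer product of the marginals a and b. For such a product coupling, the quadratic part C T0 C̄ᵀ = (C a)(C̄ b)ᵀ has rank one. So the initial linear cost depends on only four vectors. This NumPy sketch assumes the squared Euclidean loss. The function name `sq_euclidean_linearization` is hypothetical.

```python
import numpy as np


def sq_euclidean_linearization(C, Cbar, T):
    """(L ⊗ T) for L(x, y) = (x - y)^2 with lin1 = x^2, lin2 = y^2, quad1 = 2x, quad2 = y."""
    p, q = T.sum(axis=1), T.sum(axis=0)  # marginals of T
    marginal_term = (C ** 2 @ p)[:, None] + (Cbar ** 2 @ q)[None, :]
    return marginal_term - 2.0 * C @ T @ Cbar.T


rng = np.random.default_rng(1)
x, y = rng.normal(size=(6, 2)), rng.normal(size=(4, 2))
C = np.sum((x[:, None] - x[None]) ** 2, axis=-1)
Cbar = np.sum((y[:, None] - y[None]) ** 2, axis=-1)
a = np.full(6, 1 / 6)
b = np.full(4, 1 / 4)

T0 = np.outer(a, b)  # naive initializer a b^T; its marginals are exactly a and b
cost0 = sq_euclidean_linearization(C, Cbar, T0)

# Rank-one shortcut: only C^2 a, Cbar^2 b, C a and Cbar b are needed.
cheap = (C ** 2 @ a)[:, None] + (Cbar ** 2 @ b)[None, :] - 2.0 * np.outer(C @ a, Cbar @ b)
```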
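init_linearization and update_linearization suggest an outer loop. At each step the quadratic problem is linearized around the current transport, and the resulting linear entropic OT problem is solved with Sinkhorn. Peyré et al. (2016) use roughly this scheme. The loop below is a minimal NumPy sketch for the squared Euclidean loss. It assumes a fixed epsilon and iteration counts, and it has no convergence checks. ott's actual update may differ in constant factors, scheduling and stopping rules. All function names here are hypothetical.

```python
import numpy as np


def sinkhorn(cost, a, b, epsilon, num_iters=1000):
    """Balanced entropic OT between a and b for a given cost matrix."""
    # Shifting the cost by a constant rescales K but leaves the plan unchanged.
    K = np.exp(-(cost - cost.min()) / epsilon)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(num_iters):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def linearized_cost(C, Cbar, T):
    """(L ⊗ T) for the squared Euclidean loss."""
    p, q = T.sum(axis=1), T.sum(axis=0)
    return (C ** 2 @ p)[:, None] + (Cbar ** 2 @ q)[None, :] - 2.0 * C @ T @ Cbar.T


def entropic_gw(C, Cbar, a, b, epsilon=0.1, outer_iters=20):
    T = np.outer(a, b)  # naive initializer a b^T
    for _ in range(outer_iters):
        cost = linearized_cost(C, Cbar, T)  # re-linearize around the current T
        T = sinkhorn(cost, a, b, epsilon)   # solve the linear entropic problem
    return T


rng = np.random.default_rng(0)
x, y = rng.normal(size=(6, 2)), rng.normal(size=(5, 3))
C = np.sum((x[:, None] - x[None]) ** 2, axis=-1)
Cbar = np.sum((y[:, None] - y[None]) ** 2, axis=-1)
C, Cbar = C / C.max(), Cbar / Cbar.max()  # keep entries in [0, 1] for stability
a, b = np.full(6, 1 / 6), np.full(5, 1 / 5)
T = entropic_gw(C, Cbar, a, b)
```

Re-linearizing turns each outer step into an ordinary linear OT problem. So any Sinkhorn solver, including a low-rank one, can be reused unchanged inside the loop.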
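Low-rank Sinkhorn (Scetbon et al., 2021) represents the transport as T = Q diag(1/g) Rᵀ, with Q of shape (n, r), R of shape (m, r) and g of length r. For the squared Euclidean loss, the linearized cost L ⊗ T can then be kept as factors A Bᵀ of rank at most r + 2. The marginal term contributes rank 2, and the quadratic term factors through Q and R. The dense n × m transport is never formed. This is one plausible reading of the "LRC" linearization above, not a description of update_lr_geom's internals. The helper name is hypothetical, and the random Q, R, g in the example need not be a valid Sinkhorn output for the identity to hold.

```python
import numpy as np


def low_rank_linearization(C, Cbar, Q, R, g):
    """Return factors (A, B) with A @ B.T == L ⊗ T for T = Q diag(1/g) R^T."""
    p = Q @ (R.sum(axis=0) / g)  # T 1_m, computed without forming T
    q = R @ (Q.sum(axis=0) / g)  # T^T 1_n
    n, m = C.shape[0], Cbar.shape[0]
    row, col = C ** 2 @ p, Cbar ** 2 @ q  # marginal term = row 1^T + 1 col^T
    A = np.concatenate([row[:, None], np.ones((n, 1)), -2.0 * C @ Q / g], axis=1)
    B = np.concatenate([np.ones((m, 1)), col[:, None], Cbar @ R], axis=1)
    return A, B


rng = np.random.default_rng(0)
n, m, r = 6, 5, 2
x, y = rng.normal(size=(n, 2)), rng.normal(size=(m, 2))
C = np.sum((x[:, None] - x[None]) ** 2, axis=-1)
Cbar = np.sum((y[:, None] - y[None]) ** 2, axis=-1)
Q, R, g = rng.random((n, r)), rng.random((m, r)), rng.random(r) + 0.5

A, B = low_rank_linearization(C, Cbar, Q, R, g)

# Dense reference, for checking only.
T = Q @ np.diag(1.0 / g) @ R.T
dense = (C ** 2 @ T.sum(1))[:, None] + (Cbar ** 2 @ T.sum(0))[None, :] - 2.0 * C @ T @ Cbar.T
```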

Attributes

  • a – return type: ndarray

  • b – return type: ndarray

  • is_all_geoms_lr – return type: bool

  • is_balanced – return type: bool

  • is_fused – return type: bool

  • linear_loss – return type: Tuple[Callable[[ndarray], ndarray], Callable[[ndarray], ndarray]]

  • quad_loss – return type: Tuple[Callable[[ndarray], ndarray], Callable[[ndarray], ndarray]]