ott.initializers.quadratic.initializers.QuadraticInitializer


class ott.initializers.quadratic.initializers.QuadraticInitializer(init_coupling=None, **kwargs)

Initialize a linear problem locally around a selected coupling.

If the problem is balanced (tau_a = 1 and tau_b = 1), the local cost follows eq. 6, p. 1 of [Peyré et al., 2016].

If the problem is unbalanced (tau_a < 1 or tau_b < 1), there are two possible cases. The first is to introduce a quadratic KL divergence on the marginals in the objective, as done in [Sejourne et al., 2021] (gw_unbalanced_correction = True), which in turn modifies the local cost matrix.

Alternatively, the formulation of the local cost can be left unchanged, i.e. following eq. 6, p. 1 of [Peyré et al., 2016] (gw_unbalanced_correction = False), with the unbalanced terms included only at the level of the linear problem.
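To make the quadratic KL divergence of the first case concrete: in [Sejourne et al., 2021] it is the generalized KL divergence between tensor products, KL(p ⊗ p | a ⊗ a), where the generalized KL between positive vectors is sum(p log(p / q)) - sum(p) + sum(q). The NumPy sketch below evaluates it in closed form and checks the result against the brute-force definition. It only illustrates the divergence itself; how OTT weights it through tau_a and tau_b and folds it into the local cost is not shown, and the function names are hypothetical.

```python
import numpy as np


def generalized_kl(p: np.ndarray, q: np.ndarray) -> float:
    """Hypothetical helper: KL for positive, not necessarily normalized, vectors."""
    return float(np.sum(p * np.log(p / q)) - p.sum() + q.sum())


def quadratic_kl(p: np.ndarray, q: np.ndarray) -> float:
    """Hypothetical helper: KL(outer(p, p) | outer(q, q)) in closed form.

    log(p_i p_k / (q_i q_k)) splits into two sums, so the log term equals
    2 * m(p) * sum_i p_i log(p_i / q_i); the masses of the tensor products
    are m(p)**2 and m(q)**2.
    """
    m_p, m_q = p.sum(), q.sum()
    return float(2.0 * m_p * np.sum(p * np.log(p / q)) - m_p**2 + m_q**2)


rng = np.random.default_rng(0)
a = rng.uniform(0.5, 1.5, size=4)  # reference weights
p = rng.uniform(0.5, 1.5, size=4)  # e.g. a marginal of an unbalanced coupling

# The closed form agrees with the brute-force KL between outer products.
brute = generalized_kl(np.outer(p, p).ravel(), np.outer(a, a).ravel())
print(np.isclose(quadratic_kl(p, a), brute))  # True
```

The closed form avoids materializing the num_a x num_a tensor product, which matters when the marginals are large.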

Let \(P\) be the transport matrix of shape [num_a, num_b], and let cost_xx and cost_yy be the cost matrices of geom_xx and geom_yy, respectively. The functions left_x and right_y depend on the loss chosen for GW, and gw_unbalanced_correction is a flag indicating whether the unbalanced correction applies. The local cost can then be written as:

\[\text{marginal\_dep\_term} + \text{left}_x(\text{cost\_xx}) \, P \, \text{right}_y(\text{cost\_yy}) + \text{unbalanced\_correction}\]
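As an illustration of the balanced case (no unbalanced correction, no fused term), the NumPy sketch below builds the local cost for the squared loss L(a, b) = (a - b)^2, using the decomposition L(a, b) = f1(a) + f2(b) - h1(a) h2(b) of [Peyré et al., 2016] with f1(a) = a^2, f2(b) = b^2, h1(a) = a and h2(b) = 2b. Here marginal_dep_term is f1(cost_xx) p 1^T + 1 q^T f2(cost_yy)^T, with p and q the marginals of P, and the product term is written explicitly as -h1(cost_xx) P h2(cost_yy)^T. How OTT distributes that sign between left_x and right_y, and any constant scaling it applies, are assumptions not reproduced here; the helper names are hypothetical.

```python
import numpy as np


def local_cost_sqeuclidean(cost_xx: np.ndarray, cost_yy: np.ndarray, P: np.ndarray) -> np.ndarray:
    """Hypothetical helper: balanced local GW cost at P for the squared loss.

    Uses L(a, b) = f1(a) + f2(b) - h1(a) * h2(b) with f1(a) = a**2,
    f2(b) = b**2, h1(a) = a and h2(b) = 2 * b.
    """
    p = P.sum(axis=1)  # first marginal of P, shape [num_a]
    q = P.sum(axis=0)  # second marginal of P, shape [num_b]
    # marginal_dep_term: f1(cost_xx) p 1^T + 1 q^T f2(cost_yy)^T
    marginal_dep_term = ((cost_xx**2) @ p)[:, None] + ((cost_yy**2) @ q)[None, :]
    # Coupling-dependent term: -h1(cost_xx) P h2(cost_yy)^T
    cross_term = -cost_xx @ P @ (2.0 * cost_yy).T
    return marginal_dep_term + cross_term


def gw_objective_bruteforce(cost_xx, cost_yy, P):
    """sum_{i,j,k,l} (cost_xx[i, k] - cost_yy[j, l])**2 * P[i, j] * P[k, l]."""
    diff = cost_xx[:, None, :, None] - cost_yy[None, :, None, :]  # indexed [i, j, k, l]
    return float(np.einsum("ijkl,ij,kl->", diff**2, P, P))


rng = np.random.default_rng(0)
x, y = rng.normal(size=(5, 2)), rng.normal(size=(4, 3))
cost_xx = ((x[:, None] - x[None]) ** 2).sum(-1)  # [5, 5] squared Euclidean
cost_yy = ((y[:, None] - y[None]) ** 2).sum(-1)  # [4, 4] squared Euclidean
a, b = np.full(5, 1 / 5), np.full(4, 1 / 4)
P = np.outer(a, b)  # product coupling a b^T

local_cost = local_cost_sqeuclidean(cost_xx, cost_yy, P)
print(local_cost.shape)  # (5, 4)
# <local_cost, P> recovers the GW objective evaluated at P.
print(np.isclose(np.sum(local_cost * P), gw_objective_bruteforce(cost_xx, cost_yy, P)))  # True
```

The check holds for any coupling P, not only the product coupling, since the GW objective equals the inner product of this local cost with P.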

When working with the fused problem, a linear term is added to the cost matrix: cost_matrix += fused_penalty * geom_xy.cost_matrix
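A minimal sketch of the fused update in NumPy, assuming the local cost and geom_xy.cost_matrix are already available as dense [num_a, num_b] arrays; the helper name and the small matrices are hypothetical.

```python
import numpy as np


def add_fused_term(local_cost: np.ndarray, cost_xy: np.ndarray, fused_penalty: float) -> np.ndarray:
    """Hypothetical helper: local_cost + fused_penalty * cost_xy."""
    if local_cost.shape != cost_xy.shape:
        raise ValueError("cost_xy must have shape [num_a, num_b]")
    return local_cost + fused_penalty * cost_xy


# Hypothetical 2x3 example: a quadratic local cost and a linear cost between x and y.
local_cost = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
cost_xy = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0]])
print(add_fused_term(local_cost, cost_xy, fused_penalty=2.0))
# [[1. 4. 3.]
#  [6. 5. 8.]]
```

Because the fused term does not depend on the coupling, it shifts the local cost by the same matrix at every linearization.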

Parameters:
  • init_coupling (Optional[Array]) – The coupling to use for initialization. If None, defaults to the product coupling \(ab^T\).

  • kwargs (Any) – Keyword arguments.
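For reference, the default coupling \(ab^T\) used when init_coupling is None can be sketched in NumPy as an outer product; the weights below are made-up values. When a and b each sum to 1, its row and column sums recover a and b.

```python
import numpy as np

a = np.array([0.2, 0.3, 0.5])  # source weights, shape [num_a]
b = np.array([0.6, 0.4])       # target weights, shape [num_b]

# Product (independent) coupling a b^T, shape [num_a, num_b].
P0 = np.outer(a, b)

# Row sums give a * sum(b) and column sums give b * sum(a).
print(np.allclose(P0.sum(axis=1), a), np.allclose(P0.sum(axis=0), b))  # True True
```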

Methods