ott.initializers.quadratic.initializers.QuadraticInitializer
- class ott.initializers.quadratic.initializers.QuadraticInitializer(**kwargs)
Initialize a linear problem locally around the naive initial coupling \(a b^T\), the outer product of the marginals a and b.
If the problem is balanced (tau_a = 1 and tau_b = 1), the equation of the cost follows eq. 6, p. 1 of [Peyré et al., 2016].
If the problem is unbalanced (tau_a < 1 or tau_b < 1), there are two possible cases. The first is to introduce a quadratic KL divergence on the marginals in the objective, as done in [Sejourne et al., 2021] (gw_unbalanced_correction = True), which in turn modifies the local cost matrix.
Alternatively, the formulation of the local cost can be left unchanged, i.e. follow eq. 6, p. 1 of [Peyré et al., 2016] (gw_unbalanced_correction = False), with the unbalanced terms included at the level of the linear problem only.
Let \(P\) [num_a, num_b] be the transport matrix, cost_xx the cost matrix of geom_xx, and cost_yy the cost matrix of geom_yy. left_x and right_y depend on the loss chosen for GW. gw_unbalanced_correction is a boolean indicating whether the unbalanced correction applies. The equation of the local cost can be written as:
cost_matrix = marginal_dep_term
  + left_x(cost_xx) \(P\) right_y(cost_yy)\(^T\)
  + unbalanced_correction * gw_unbalanced_correction
When working with the fused problem, a linear term is added to the cost matrix: cost_matrix += fused_penalty * geom_xy.cost_matrix
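The balanced local cost and the fused term can be illustrated with a minimal NumPy sketch. It does not call OTT. It assumes the squared Euclidean GW loss, for which the decomposition of [Peyré et al., 2016] gives a marginal-dependent term built from the squared cost matrices and a cross term with factor -2. Folding that sign into right_y is an assumption of this sketch, as are the random data, the value of fused_penalty, and the plain array cost_xy standing in for geom_xy.cost_matrix. The result is checked against a brute-force contraction of the 4-way loss tensor with \(P\).

```python
import numpy as np

rng = np.random.default_rng(0)
num_a, num_b = 4, 3


def sq_dists(x):
    """Pairwise squared Euclidean distances between the rows of x."""
    diff = x[:, None, :] - x[None, :, :]
    return (diff ** 2).sum(-1)


# Intra-domain cost matrices (stand-ins for geom_xx and geom_yy).
cost_xx = sq_dists(rng.normal(size=(num_a, 2)))
cost_yy = sq_dists(rng.normal(size=(num_b, 3)))

# Uniform marginals and the naive initial coupling a b^T.
a = np.full(num_a, 1.0 / num_a)
b = np.full(num_b, 1.0 / num_b)
P = np.outer(a, b)


# Assumed squared-loss factors; the minus sign lives in right_y here.
def left_x(c):
    return c


def right_y(c):
    return -2.0 * c


# Term that depends only on the marginals of P.
marginal_dep_term = (cost_xx ** 2 @ a)[:, None] + (cost_yy ** 2 @ b)[None, :]

# Balanced local cost (no unbalanced correction).
cost_matrix = marginal_dep_term + left_x(cost_xx) @ P @ right_y(cost_yy).T

# Brute-force check: sum_{k,l} (cost_xx[i,k] - cost_yy[j,l])^2 * P[k,l].
loss_tensor = (cost_xx[:, None, :, None] - cost_yy[None, :, None, :]) ** 2
brute_force = np.einsum("ijkl,kl->ij", loss_tensor, P)

# Fused problem: add a linear term (hypothetical cross-domain cost).
fused_penalty = 0.5
cost_xy = rng.uniform(size=(num_a, num_b))
fused_cost_matrix = cost_matrix + fused_penalty * cost_xy
```

The factorized form needs only two matrix products, whereas the brute-force tensor has num_a² × num_b² entries, which is why the local cost is written in the factorized way.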
Methods
- Parameters
kwargs (Any) –