Initialize a linear problem locally around the naive initializer $$a b^T$$ (the outer product of the two marginal weight vectors).
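The naive initializer is simply the outer product of the two weight vectors. The following minimal NumPy sketch builds it and checks that its marginals match the inputs when both are probability vectors. `naive_initializer` is a hypothetical helper for illustration, not a library function.

```python
import numpy as np

def naive_initializer(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Hypothetical helper: the outer-product coupling a b^T."""
    return np.outer(a, b)

a = np.array([0.2, 0.3, 0.5])  # weights of the first measure
b = np.array([0.25, 0.75])     # weights of the second measure
P0 = naive_initializer(a, b)

# Row sums are a * b.sum() and column sums are b * a.sum();
# since both vectors sum to 1 here, these are exactly a and b.
print(P0.sum(axis=1), P0.sum(axis=0))
```

For unnormalized weights the marginals of $$a b^T$$ are rescaled by the other vector's total mass, which matters in the unbalanced case.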

If the problem is balanced (tau_a = 1.0 and tau_b = 1.0), the equation of the cost follows eq. 6, p. 1 of .
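For reference, the balanced local cost relies on a separable decomposition of the GW loss. The block below restates that decomposition as a sketch. The functions are applied elementwise. Identifying $$h_1$$ with left_x and $$h_2$$ with right_y, and the square-loss choices in the last line, are assumptions made for illustration.

```latex
% Separable loss: L(x, y) = f_1(x) + f_2(y) - h_1(x) h_2(y)
\big(\mathcal{L}(C_x, C_y) \otimes P\big)_{ik}
  = \sum_{j,l} L\big((C_x)_{ij}, (C_y)_{kl}\big)\, P_{jl}
  = \Big( f_1(C_x)\, p\, \mathbf{1}^T + \mathbf{1}\, q^T f_2(C_y)^T
          - h_1(C_x)\, P\, h_2(C_y)^T \Big)_{ik},
\qquad p = P\mathbf{1},\quad q = P^T\mathbf{1}.
% Square loss: f_1(x) = x^2, f_2(y) = y^2, h_1(x) = x, h_2(y) = 2y.
```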

If the problem is unbalanced (tau_a < 1.0 or tau_b < 1.0), there are two possible approaches. The first is to introduce a quadratic KL divergence on the marginals in the objective, as done in (gw_unbalanced_correction = True), which in turn modifies the local cost matrix.
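To make "quadratic KL divergence on the marginals" concrete, the sketch below computes the generalized KL divergence between the outer products $$\mu \mu^T$$ and $$\nu \nu^T$$. It checks this against the identity $$KL(\mu \otimes \mu | \nu \otimes \nu) = 2\, m(\mu)\, KL(\mu | \nu) + (m(\mu) - m(\nu))^2$$, where $$m$$ denotes total mass. The helpers `kl`, `quad_kl` and `quad_kl_closed_form` are hypothetical. The sketch does not reproduce how the library derives unbalanced_correction from this penalty.

```python
import numpy as np

def kl(mu, nu):
    """Generalized KL divergence between positive, possibly unnormalized vectors."""
    return float(np.sum(mu * np.log(mu / nu)) - mu.sum() + nu.sum())

def quad_kl(mu, nu):
    """Quadratic KL: KL divergence between the outer products mu mu^T and nu nu^T."""
    return kl(np.outer(mu, mu).ravel(), np.outer(nu, nu).ravel())

def quad_kl_closed_form(mu, nu):
    """Same quantity via 2 m(mu) KL(mu | nu) + (m(mu) - m(nu))**2."""
    m_mu, m_nu = mu.sum(), nu.sum()
    return 2.0 * m_mu * kl(mu, nu) + (m_mu - m_nu) ** 2

mu = np.array([0.1, 0.4, 0.3])  # e.g. a marginal of the current coupling (mass 0.8)
nu = np.array([0.2, 0.3, 0.5])  # e.g. the target weights (mass 1.0)
print(quad_kl(mu, nu), quad_kl_closed_form(mu, nu))
```

The identity shows the penalty acts like a mass-weighted KL term plus a squared mismatch of total mass.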

Alternatively, the formulation of the local cost can be left unchanged, i.e. following eq. 6, p. 1 of (gw_unbalanced_correction = False), with the unbalanced terms included at the level of the linear problem only.

Let $$P$$ be the transport matrix of shape [num_a, num_b], cost_xx the cost matrix of geom_xx, and cost_yy the cost matrix of geom_yy. left_x and right_y depend on the loss chosen for GW, and gw_unbalanced_correction is a boolean indicating whether the unbalanced correction applies. The local cost can then be written as:

cost_matrix = marginal_dep_term - left_x(cost_xx) $$P$$ right_y(cost_yy)$$^T$$ + unbalanced_correction * gw_unbalanced_correction
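A minimal NumPy sketch of the balanced case (no unbalanced correction, no fused term) for the square loss. It assumes the split $$L(x, y) = x^2 + y^2 - \text{left\_x}(x)\, \text{right\_y}(y)$$ with left_x(x) = x and right_y(y) = 2y. Under that assumption marginal_dep_term uses the row and column marginals of $$P$$. The result is checked against a brute-force sum over all index pairs. The function names are illustrative, not the library's.

```python
import numpy as np

# Assumed square-loss split: L(x, y) = x**2 + y**2 - left_x(x) * right_y(y).
def left_x(x):
    return x

def right_y(y):
    return 2.0 * y

def local_cost(cost_xx, cost_yy, P):
    """Balanced local cost: marginal_dep_term - left_x(cost_xx) P right_y(cost_yy)^T."""
    p = P.sum(axis=1)  # row marginal of P, shape [num_a]
    q = P.sum(axis=0)  # column marginal of P, shape [num_b]
    marginal_dep_term = (cost_xx ** 2 @ p)[:, None] + (cost_yy ** 2 @ q)[None, :]
    return marginal_dep_term - left_x(cost_xx) @ P @ right_y(cost_yy).T

def local_cost_brute_force(cost_xx, cost_yy, P):
    # cost[i, k] = sum_{j, l} (cost_xx[i, j] - cost_yy[k, l])**2 * P[j, l]
    diff = cost_xx[:, None, :, None] - cost_yy[None, :, None, :]
    return np.einsum("ikjl,jl->ik", diff ** 2, P)

# Toy point clouds in different dimensions, with squared-Euclidean intra-domain costs.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 2))
y = rng.normal(size=(3, 5))
cost_xx = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
cost_yy = ((y[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a, b = np.full(4, 1 / 4), np.full(3, 1 / 3)
P = np.outer(a, b)  # naive initializer a b^T

print(np.allclose(local_cost(cost_xx, cost_yy, P),
                  local_cost_brute_force(cost_xx, cost_yy, P)))
```

The factored form costs a few matrix products instead of the $$O(n^2 m^2)$$ brute-force sum, which is why the decomposition is used.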

When working with the fused problem, a linear term is added to the cost matrix:

cost_matrix += fused_penalty * geom_xy.cost_matrix

Parameters

epsilon (Union[Epsilon, float, None]) – An epsilon scheduler or a float passed on to the linearization.

Return type

LinearProblem

Returns

A linear_problems.LinearProblem representing the local linearization of the GW problem.