QuadraticProblem.cost_unbalanced_correction(transport_matrix, marginal_1, marginal_2, epsilon, rescale_factor, delta=1e-09)

Calculate the cost term arising from the quadratic divergence in the unbalanced case.

In the unbalanced setting (tau_a < 1.0 or tau_b < 1.0), introducing a quadratic divergence adds a term to the Gromov-Wasserstein (GW) local cost.

Let $a$ [num_a,] be the target weights for samples from geom_xx and $b$ [num_b,] be the target weights for samples from geom_yy. Let $P$ [num_a, num_b] be the transport matrix, $P\mathbf{1}$ the first marginal and $P^T\mathbf{1}$ the second marginal. The term of the cost matrix coming from the quadratic KL in the unbalanced case can be written as:

$$\text{unbalanced\_correction\_term} = \frac{\tau_a}{1 - \tau_a} \sum \mathrm{KL}(P\mathbf{1} \,|\, a) + \frac{\tau_b}{1 - \tau_b} \sum \mathrm{KL}(P^T\mathbf{1} \,|\, b) + \varepsilon \sum \mathrm{KL}(P \,|\, a b^T)$$
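The three terms of the formula can be evaluated directly in a small plain-Python sketch. This is an illustration under stated assumptions, not the library's implementation. The function names `generalized_kl` and `unbalanced_correction_term` are hypothetical. The sketch assumes KL denotes the generalized (unnormalized) divergence $\sum_i p_i \log(p_i / q_i) - p_i + q_i$ with the convention $0 \log 0 = 0$. It also takes plain floats and nested lists rather than arrays and an epsilon object, computes the marginals from $P$ instead of receiving them as `marginal_1` and `marginal_2`, and ignores `rescale_factor` and `delta`.

```python
import math
from typing import List, Sequence


def generalized_kl(p: Sequence[float], q: Sequence[float]) -> float:
    """Summed elementwise generalized KL: p*log(p/q) - p + q, with 0*log(0) = 0."""
    total = 0.0
    for p_i, q_i in zip(p, q):
        if p_i > 0.0:
            total += p_i * math.log(p_i / q_i)
        total += q_i - p_i
    return total


def unbalanced_correction_term(
    P: List[List[float]],
    a: Sequence[float],
    b: Sequence[float],
    tau_a: float,
    tau_b: float,
    epsilon: float,
) -> float:
    """Hypothetical helper: evaluates the three-term correction for a dense P."""
    row_marginal = [sum(row) for row in P]        # P1, length num_a
    col_marginal = [sum(col) for col in zip(*P)]  # P^T 1, length num_b
    flat_P = [p_ij for row in P for p_ij in row]
    # Outer product a b^T, flattened in the same row-major order as flat_P.
    flat_ab = [a_i * b_j for a_i in a for b_j in b]

    cost = epsilon * generalized_kl(flat_P, flat_ab)
    # Assumption: a side with tau == 1.0 is balanced, so its term is skipped
    # rather than evaluating tau / (1 - tau), which would divide by zero.
    if tau_a < 1.0:
        cost += tau_a / (1.0 - tau_a) * generalized_kl(row_marginal, a)
    if tau_b < 1.0:
        cost += tau_b / (1.0 - tau_b) * generalized_kl(col_marginal, b)
    return cost


# Usage: when P equals the outer product a b^T, every KL term vanishes.
a = [0.5, 0.5]
b = [0.25, 0.75]
P = [[a_i * b_j for b_j in b] for a_i in a]
print(unbalanced_correction_term(P, a, b, tau_a=0.9, tau_b=0.9, epsilon=0.1))  # prints 0.0
```

Skipping a marginal term when its tau equals 1.0 is a design choice of this sketch: in the balanced limit the coefficient $\tau / (1 - \tau)$ diverges while the corresponding marginal constraint is meant to hold exactly.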

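Parameters

transport_matrix – transport matrix $P$ of shape [num_a, num_b].
marginal_1 – first marginal $P\mathbf{1}$ of shape [num_a,].
marginal_2 – second marginal $P^T\mathbf{1}$ of shape [num_b,].
epsilon – entropic regularization strength (epsilon in the formula above).
rescale_factor – rescaling factor.
delta – small constant, 1e-09 by default.

Return type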

float

Returns

The cost term.