Calculate cost term from the quadratic divergence when unbalanced.

In the unbalanced setting (tau_a < 1.0 or tau_b < 1.0), introducing a quadratic divergence adds a correction term to the local Gromov-Wasserstein (GW) cost.

Let $$a$$ [num_a,] be the target weights for samples from geom_xx and $$b$$ [num_b,] the target weights for samples from geom_yy. Let $$P$$ [num_a, num_b] be the transport matrix, $$P\mathbf{1}$$ its first marginal (row sums) and $$P^T\mathbf{1}$$ its second marginal (column sums). The scalar term that the quadratic KL contributes to the cost in the unbalanced case can be written as:
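The snippet below is a short JAX illustration of the marginals defined above. It builds a small transport matrix P, whose values are made up for the example, and computes $$P\mathbf{1}$$ and $$P^T\mathbf{1}$$.

```python
import jax.numpy as jnp

# Illustrative transport matrix P with shape [num_a, num_b] = [2, 3] (made-up values).
P = jnp.array([[0.10, 0.20, 0.00],
               [0.05, 0.25, 0.40]])

# First marginal P1: multiply by a vector of ones, i.e. sum each row.
# One entry per sample of geom_xx, shape [num_a,].
marginal_1 = P @ jnp.ones(P.shape[1])

# Second marginal P^T1: sum each column.
# One entry per sample of geom_yy, shape [num_b,].
marginal_2 = P.T @ jnp.ones(P.shape[0])

print(marginal_1, marginal_2)
```

The matrix-vector products mirror the $$P\mathbf{1}$$ and $$P^T\mathbf{1}$$ notation. P.sum(axis=1) and P.sum(axis=0) give the same values.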

unbalanced_correction_term =

$$\frac{\tau_a}{1 - \tau_a} \sum \mathrm{KL}(P\mathbf{1} \,|\, a) + \frac{\tau_b}{1 - \tau_b} \sum \mathrm{KL}(P^T\mathbf{1} \,|\, b) + \varepsilon \sum \mathrm{KL}(P \,|\, ab^T)$$
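The following is a minimal JAX sketch that evaluates the expression above term by term. It rests on several assumptions:

- KL denotes the generalized (unnormalized) KL divergence, $$\mathrm{KL}(p|q) = \sum p \log(p/q) - p + q$$.
- epsilon is a plain float rather than an Epsilon object.
- A side with tau equal to 1.0 (balanced) contributes no term.

The helper names unbalanced_correction_term, generalized_kl and marginal_penalty, and the extra arguments a, b, tau_a and tau_b, are hypothetical. They are not the library's API.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def generalized_kl(p: jnp.ndarray, q: jnp.ndarray) -> jnp.ndarray:
    """Summed generalized KL: sum(p * log(p / q) - p + q).

    Assumption: "KL" in the formula is this unnormalized KL.
    xlogy(0, y) returns 0, so zero entries of p contribute only q.
    """
    return jnp.sum(xlogy(p, p / q) - p + q)


def marginal_penalty(tau: float, p: jnp.ndarray, q: jnp.ndarray) -> jnp.ndarray:
    """tau / (1 - tau) * KL(p | q).

    Assumption: when tau == 1 the marginal is enforced exactly (balanced side),
    so the term is dropped instead of evaluating inf * 0.
    """
    if tau == 1.0:
        return jnp.zeros(())
    return tau / (1.0 - tau) * generalized_kl(p, q)


def unbalanced_correction_term(transport_matrix, marginal_1, marginal_2,
                               a, b, tau_a, tau_b, epsilon):
    """Hypothetical helper evaluating the documented expression term by term."""
    term_a = marginal_penalty(tau_a, marginal_1, a)  # tau_a / (1 - tau_a) * KL(P1 | a)
    term_b = marginal_penalty(tau_b, marginal_2, b)  # tau_b / (1 - tau_b) * KL(P^T1 | b)
    # epsilon * KL(P | ab^T), where ab^T is the outer product of the target weights.
    term_p = epsilon * generalized_kl(transport_matrix, jnp.outer(a, b))
    return term_a + term_b + term_p
```

If P equals $$ab^T$$ and both weight vectors sum to one, the marginals match their targets and every KL term vanishes, so the sketch returns zero. This gives a convenient sanity check.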

Parameters
• transport_matrix (Array) – jnp.ndarray<float>[num_a, num_b], transport matrix.

• marginal_1 (Array) – jnp.ndarray<float>[num_a,], marginal of the transport matrix for samples from geom_xx.

• marginal_2 (Array) – jnp.ndarray<float>[num_b,], marginal of the transport matrix for samples from geom_yy.

• epsilon (Epsilon) – entropic regulariser ($$\varepsilon$$ in the expression above).

Return type

float

Returns

The unbalanced correction term (a scalar).