ott.problems.quadratic.quadratic_problem.QuadraticProblem.cost_unbalanced_correction
- QuadraticProblem.cost_unbalanced_correction(transport_matrix, marginal_1, marginal_2, epsilon)
Calculate cost term from the quadratic divergence when unbalanced.
In the unbalanced setting (tau_a < 1.0 or tau_b < 1.0), the introduction of a quadratic divergence [Sejourne et al., 2021] adds a term to the GW local cost. Let \(a\) [num_a,] be the target weights for samples from geom_xx and \(b\) [num_b,] be the target weights for samples from geom_yy. Let \(P\) [num_a, num_b] be the transport matrix, \(P1\) the first marginal and \(P^T1\) the second marginal. The term of the cost matrix coming from the quadratic KL in the unbalanced case can be written as:
- unbalanced_correction_term =
\(\frac{\tau_a}{1 - \tau_a} \sum \mathrm{KL}(P1 \mid a) + \frac{\tau_b}{1 - \tau_b} \sum \mathrm{KL}(P^T1 \mid b) + \epsilon \sum \mathrm{KL}(P \mid ab')\)
- Parameters:
transport_matrix (Array) – jnp.ndarray<float>[num_a, num_b], transport matrix.
marginal_1 (Array) – jnp.ndarray<float>[num_a,], marginal of the transport matrix for samples from geom_xx.
marginal_2 (Array) – jnp.ndarray<float>[num_b,], marginal of the transport matrix for samples from geom_yy.
epsilon (float) – entropy regularizer.
- Return type:
- Returns:
The cost term.
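To make the correction term above concrete, the following is a minimal standalone JAX sketch of the formula as it is written in this docstring. It is not the library's implementation: `kl_sum` and `unbalanced_correction_sketch` are hypothetical names, the weights `a`, `b` and the parameters `tau_a`, `tau_b` are passed explicitly rather than read from the problem object, and KL is assumed to be the plain elementwise form p * log(p / q) summed over all entries. The library may use a different KL convention or scaling.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def kl_sum(p, q):
  """Sum of elementwise p * log(p / q), treating 0 * log(0) as 0.

  Assumption: no mass-correction terms (-p + q) are included.
  """
  return jnp.sum(xlogy(p, p) - xlogy(p, q))


def unbalanced_correction_sketch(
    transport_matrix, marginal_1, marginal_2, epsilon, a, b, tau_a, tau_b
):
  """Hypothetical helper that evaluates the documented formula term by term."""
  # epsilon * sum(KL(P | ab')), where ab' is the outer product of the weights.
  term = epsilon * kl_sum(transport_matrix, jnp.outer(a, b))
  # The marginal penalties only apply on the unbalanced side(s);
  # tau == 1.0 would otherwise divide by zero.
  if tau_a != 1.0:
    term += tau_a / (1.0 - tau_a) * kl_sum(marginal_1, a)
  if tau_b != 1.0:
    term += tau_b / (1.0 - tau_b) * kl_sum(marginal_2, b)
  return term
```

As a sanity check under these assumptions, a transport matrix equal to the outer product of `a` and `b` has marginals `a` and `b` (when both sum to one), so every KL term vanishes and the sketch returns zero.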