ott.problems.quadratic.quadratic_problem.QuadraticProblem.cost_unbalanced_correction

QuadraticProblem.cost_unbalanced_correction(transport_matrix, marginal_1, marginal_2, epsilon)

Calculate the cost term that arises from the quadratic divergence in the unbalanced setting.

In the unbalanced setting (tau_a < 1.0 or tau_b < 1.0), introducing a quadratic divergence [Séjourné et al., 2021] adds a term to the local Gromov-Wasserstein (GW) cost.

Let \(a\) [num_a,] be the target weights for samples from geom_xx and \(b\) [num_b,] the target weights for samples from geom_yy. Let \(P\) [num_a, num_b] be the transport matrix, \(P\mathbf{1}\) its first marginal and \(P^\top\mathbf{1}\) its second marginal, where \(\mathbf{1}\) denotes a vector of ones. In the unbalanced case, the term of the cost matrix coming from the quadratic KL divergence can be written as:

\[
\text{unbalanced\_correction\_term} = \frac{\tau_a}{1 - \tau_a} \sum \mathrm{KL}(P\mathbf{1} \,|\, a) + \frac{\tau_b}{1 - \tau_b} \sum \mathrm{KL}(P^\top\mathbf{1} \,|\, b) + \varepsilon \sum \mathrm{KL}(P \,|\, a b^\top)
\]
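The formula above can be illustrated with a minimal JAX sketch. It is not OTT's implementation and rests on several assumptions: KL is taken to be the generalized (unnormalized) divergence \(\mathrm{KL}(p \,|\, q) = \sum p \log(p/q) - p + q\); epsilon is passed as a plain float rather than an Epsilon object; and \(a\), \(b\), tau_a and tau_b are explicit arguments (on QuadraticProblem they are attributes of the problem). The names unbalanced_correction_sketch and generalized_kl are hypothetical.

```python
import jax.numpy as jnp
from jax.scipy.special import xlogy


def generalized_kl(p: jnp.ndarray, q: jnp.ndarray) -> jnp.ndarray:
    # Assumed form: sum(p * log(p / q) - p + q) for nonnegative p and positive q.
    # xlogy(0, .) returns 0, so zero entries of p contribute only +q.
    return jnp.sum(xlogy(p, p / q) - p + q)


def unbalanced_correction_sketch(
    transport_matrix: jnp.ndarray,  # P, shape [num_a, num_b]
    marginal_1: jnp.ndarray,        # P 1, shape [num_a]
    marginal_2: jnp.ndarray,        # P^T 1, shape [num_b]
    a: jnp.ndarray,                 # target weights for geom_xx, shape [num_a]
    b: jnp.ndarray,                 # target weights for geom_yy, shape [num_b]
    tau_a: float,                   # assumed < 1.0 so tau_a / (1 - tau_a) is finite
    tau_b: float,                   # assumed < 1.0 so tau_b / (1 - tau_b) is finite
    epsilon: float,                 # plain float here, not an Epsilon object
) -> jnp.ndarray:
    # First marginal penalty, weighted by tau_a / (1 - tau_a).
    cost = tau_a / (1.0 - tau_a) * generalized_kl(marginal_1, a)
    # Second marginal penalty, weighted by tau_b / (1 - tau_b).
    cost += tau_b / (1.0 - tau_b) * generalized_kl(marginal_2, b)
    # Entropic term: KL between P and the product measure a b^T.
    cost += epsilon * generalized_kl(transport_matrix, jnp.outer(a, b))
    return cost


# Usage: the product coupling P = a b^T matches both target marginals,
# so every KL term is zero and the correction vanishes.
a = jnp.array([0.5, 0.5])
b = jnp.array([0.25, 0.75])
P = jnp.outer(a, b)
print(unbalanced_correction_sketch(P, P.sum(axis=1), P.sum(axis=0), a, b, 0.5, 0.5, 0.1))
```

Moving mass away from \(a b^\top\), or changing the total mass of \(P\), makes the corresponding KL terms positive, so the correction grows as the coupling departs from the target weights.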

Parameters
  • transport_matrix (Array) – jnp.ndarray<float>[num_a, num_b], transport matrix.

  • marginal_1 (Array) – jnp.ndarray<float>[num_a,], marginal of the transport matrix for samples from geom_xx.

  • marginal_2 (Array) – jnp.ndarray<float>[num_b,], marginal of the transport matrix for samples from geom_yy.

  • epsilon (Epsilon) – entropic regularisation parameter.

Return type

float

Returns

The cost term.