ott.core.quad_problems.QuadraticProblem.cost_unbalanced_correction

QuadraticProblem.cost_unbalanced_correction(transport_matrix, marginal_1, marginal_2, epsilon, rescale_factor, delta=1e-09)

Calculate the cost term arising from the quadratic divergence in the unbalanced case.

In the unbalanced setting (tau_a < 1.0 or tau_b < 1.0), introducing a quadratic divergence [Séjourné et al., 2021] adds a term to the local Gromov-Wasserstein (GW) cost.

Let \(a\) [num_a,] be the target weights for samples from geom_xx and \(b\) [num_b,] be the target weights for samples from geom_yy. Let \(P\) [num_a, num_b] be the transport matrix, \(P1\) the first marginal and \(P^T1\) the second marginal. The term of the cost matrix coming from the quadratic KL in the unbalanced case can be written as:

unbalanced_correction_term =

\[ \frac{\tau_a}{1 - \tau_a} \sum \mathrm{KL}(P1 \mid a) + \frac{\tau_b}{1 - \tau_b} \sum \mathrm{KL}(P^T1 \mid b) + \varepsilon \sum \mathrm{KL}(P \mid ab^T) \]
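To make the three terms of the formula above concrete, here is a minimal NumPy sketch. It is an illustration under stated assumptions, not OTT's implementation: it assumes KL denotes the generalized (unnormalized) Kullback-Leibler divergence \(\sum_i p_i \log(p_i / q_i) - p_i + q_i\), assumes `delta` is added inside the logarithm, and leaves out `rescale_factor`, since the formula does not show how it enters. Because \(\tau / (1 - \tau)\) is undefined at \(\tau = 1\) (a balanced marginal), the sketch simply drops that marginal's term in that case. The helpers `generalized_kl` and `unbalanced_correction_term`, and the explicit `a`, `b`, `tau_a`, `tau_b` arguments, are hypothetical.

```python
import numpy as np


def generalized_kl(p, q, delta=1e-9):
    """Sum over entries of p * log(p / q) - p + q (assumed generalized KL).

    Assumption: `delta` is added inside the log so that zero entries of
    p or q never produce log(0).
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log((p + delta) / (q + delta)) - p + q))


def unbalanced_correction_term(transport_matrix, marginal_1, marginal_2,
                               a, b, tau_a, tau_b, epsilon, delta=1e-9):
    """Hypothetical helper evaluating the formula; ignores rescale_factor."""
    # epsilon * KL(P | ab^T): entropic term relative to the product of weights.
    cost = epsilon * generalized_kl(transport_matrix, np.outer(a, b), delta)
    # tau / (1 - tau) is undefined at tau == 1 (balanced marginal), so this
    # sketch drops that marginal's term in that case.
    if tau_a != 1.0:
        cost += tau_a / (1.0 - tau_a) * generalized_kl(marginal_1, a, delta)
    if tau_b != 1.0:
        cost += tau_b / (1.0 - tau_b) * generalized_kl(marginal_2, b, delta)
    return cost


# Usage: a plan equal to ab^T has marginals a and b, so every KL term is zero.
a = np.array([0.5, 0.5])
b = np.array([0.25, 0.75])
P = np.outer(a, b)
print(unbalanced_correction_term(P, P.sum(axis=1), P.sum(axis=0),
                                 a, b, tau_a=0.9, tau_b=0.9, epsilon=0.1))
# prints 0.0
```

Doubling the plan instead (marginals \(2a\) and \(2b\)) makes each marginal KL equal \(2\log 2 - 1\) here, so the marginal terms grow with how far \(P1\) and \(P^T1\) drift from the target weights, scaled by \(\tau / (1 - \tau)\).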

Parameters
  • transport_matrix (ndarray) – jnp.ndarray<float>[num_a, num_b], transport matrix.

  • marginal_1 (ndarray) – jnp.ndarray<float>[num_a,], marginal of the transport matrix for samples from geom_xx.

  • marginal_2 (ndarray) – jnp.ndarray<float>[num_b,], marginal of the transport matrix for samples from geom_yy.

  • epsilon (float) – entropic regularisation strength \(\varepsilon\).

  • rescale_factor (float) – scaling factor for the transport matrix.

  • delta (float) – small constant that keeps the KL terms from diverging.
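The two marginal arguments are the row and column sums of transport_matrix (\(P1\) and \(P^T1\) in the notation above). The short NumPy sketch below, using made-up weights and a made-up plan, shows how they can be formed and what shapes to expect; the same `.sum(axis=...)` calls work on `jax.numpy` arrays.

```python
import numpy as np

# Made-up weights and plan: num_a = 2 samples (geom_xx), num_b = 3 (geom_yy).
a = np.array([0.5, 0.5])
b = np.array([0.2, 0.3, 0.5])
transport_matrix = np.array([[0.10, 0.10, 0.20],
                             [0.05, 0.25, 0.30]])

# marginal_1 = P1: sum over columns, one entry per geom_xx sample -> [num_a,]
marginal_1 = transport_matrix.sum(axis=1)
# marginal_2 = P^T1: sum over rows, one entry per geom_yy sample -> [num_b,]
marginal_2 = transport_matrix.sum(axis=0)

# When unbalanced, the marginals are only pushed towards a and b by the KL
# penalties, so they need not equal them (here marginal_1 is [0.4, 0.6]).
print(marginal_1, a)
print(marginal_2, b)
```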

Return type

float

Returns

The cost term.