ott.problems.quadratic.quadratic_problem.QuadraticProblem.marginal_dependent_cost
- QuadraticProblem.marginal_dependent_cost(marginal_1, marginal_2)
Initialise cost term that depends on the marginals of the transport.
Uses the first term in eq. 6, p. 1 of [Peyré et al., 2016].
Let \(p\) [num_a,] be the marginal of the transport matrix for samples from geom_xx and \(q\) [num_b,] be the marginal of the transport matrix for samples from geom_yy. cost_xx (resp. cost_yy) is the cost matrix of geom_xx (resp. geom_yy). The cost term that depends on these marginals can be written as:
- marginal_dep_term = \(\mathrm{lin1}(\mathrm{cost\_xx})\, p\, \mathbb{1}_{num_b}^T + \left(\mathrm{lin2}(\mathrm{cost\_yy})\, q\, \mathbb{1}_{num_a}^T\right)^T\)
- Parameters
marginal_1 (Array) – jnp.ndarray<float>[num_a,], marginal of the transport matrix for samples from geom_xx
marginal_2 (Array) – jnp.ndarray<float>[num_b,], marginal of the transport matrix for samples from geom_yy
- Return type
- Returns
Low-rank geometry.
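
The marginal-dependent term described above can be illustrated with a small dense sketch in jax.numpy. This is not the library's implementation: the helper names `dense_marginal_dep_term` and `rank_two_factors` are hypothetical, and `lin1` and `lin2` are assumed here to be elementwise squaring, as for the squared-Euclidean loss in [Peyré et al., 2016]. The second helper shows that the term is a sum of two rank-one matrices, which is consistent with the method returning a low-rank geometry.

```python
import jax.numpy as jnp


def dense_marginal_dep_term(cost_xx, cost_yy, p, q,
                            lin1=jnp.square, lin2=jnp.square):
    """Dense [num_a, num_b] matrix of the marginal-dependent term.

    lin1 and lin2 are elementwise functions; squaring is an assumed default.
    """
    num_a, num_b = p.shape[0], q.shape[0]
    # lin1(cost_xx) p 1_{num_b}^T: a [num_a] vector repeated along the columns.
    term_x = jnp.outer(lin1(cost_xx) @ p, jnp.ones(num_b))
    # (lin2(cost_yy) q 1_{num_a}^T)^T: a [num_b] vector repeated along the rows.
    term_y = jnp.outer(lin2(cost_yy) @ q, jnp.ones(num_a)).T
    return term_x + term_y


def rank_two_factors(cost_xx, cost_yy, p, q,
                     lin1=jnp.square, lin2=jnp.square):
    """Factors (cost_1, cost_2) such that cost_1 @ cost_2.T equals the term."""
    u = lin1(cost_xx) @ p  # shape [num_a]
    v = lin2(cost_yy) @ q  # shape [num_b]
    cost_1 = jnp.stack([u, jnp.ones_like(u)], axis=1)  # [num_a, 2]
    cost_2 = jnp.stack([jnp.ones_like(v), v], axis=1)  # [num_b, 2]
    return cost_1, cost_2


# Tiny made-up example: 2 points on the geom_xx side, 3 on the geom_yy side.
cost_xx = jnp.array([[0.0, 1.0], [1.0, 0.0]])
cost_yy = jnp.array([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]])
p = jnp.array([0.5, 0.5])
q = jnp.array([1.0, 1.0, 1.0]) / 3.0

term = dense_marginal_dep_term(cost_xx, cost_yy, p, q)
cost_1, cost_2 = rank_two_factors(cost_xx, cost_yy, p, q)
```

Storing only the two factors costs O(num_a + num_b) memory instead of O(num_a · num_b) for the dense matrix, which is the usual reason to prefer a low-rank representation for this term.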