ott.core.quad_problems.QuadraticProblem.marginal_dependent_cost

QuadraticProblem.marginal_dependent_cost(marginal_1, marginal_2)

Initialise the cost term that depends on the marginals of the transport matrix.

Uses the first term in eq. 6, p. 1 of [Peyré et al., 2016].

Let \(p\) [num_a,] be the marginal of the transport matrix for samples from geom_xx and \(q\) [num_b,] be the marginal of the transport matrix for samples from geom_yy. cost_xx (resp. cost_yy) is the cost matrix of geom_xx (resp. geom_yy). The cost term that depends on these marginals can be written as:

\(\mathrm{marginal\_dep\_term} = \mathrm{lin1}(\mathrm{cost\_xx})\, p\, \mathbb{1}_{num_b}^T + \left(\mathrm{lin2}(\mathrm{cost\_yy})\, q\, \mathbb{1}_{num_a}^T\right)^T\)
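As a rough illustration of this formula, the NumPy sketch below builds the dense [num_a, num_b] term from two cost matrices and a coupling (jax.numpy offers the same calls). It assumes lin1 and lin2 act element-wise and uses np.square for both, the choice matching the decomposition (a - b)^2 = a^2 + b^2 - 2ab of the squared loss. The helper name marginal_dep_term and the toy data are hypothetical and not part of ott's API.

```python
import numpy as np


def marginal_dep_term(cost_xx, cost_yy, p, q, lin1=np.square, lin2=np.square):
    """Hypothetical helper: dense lin1(C_xx) p 1^T + (lin2(C_yy) q 1^T)^T.

    lin1 / lin2 are assumed to be element-wise maps; np.square is only an
    illustrative default (squared-Euclidean loss).
    """
    num_a, num_b = p.shape[0], q.shape[0]
    # lin1(C_xx) p is a [num_a] vector; the outer product with 1_{num_b}
    # repeats it across the num_b columns.
    x_part = np.outer(lin1(cost_xx) @ p, np.ones(num_b))
    # lin2(C_yy) q is a [num_b] vector; the outer product with 1_{num_a}
    # gives [num_b, num_a], transposed back to [num_a, num_b].
    y_part = np.outer(lin2(cost_yy) @ q, np.ones(num_a)).T
    return x_part + y_part


# Toy data: a uniform coupling T between 2 and 3 points.
cost_xx = np.array([[0.0, 1.0], [1.0, 0.0]])
cost_yy = np.array([[0.0, 2.0, 1.0], [2.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
T = np.full((2, 3), 1.0 / 6.0)
p = T.sum(axis=1)  # marginal for samples from geom_xx, shape [num_a]
q = T.sum(axis=0)  # marginal for samples from geom_yy, shape [num_b]

term = marginal_dep_term(cost_xx, cost_yy, p, q)
print(term.shape)  # (2, 3)
```

Here p and q are taken as the row and column sums of T, which is what "marginal of the transport matrix" means for a [num_a, num_b] coupling.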

Parameters
  • marginal_1 (ndarray) – jnp.ndarray<float>[num_a,], marginal of the transport matrix for samples from geom_xx

  • marginal_2 (ndarray) – jnp.ndarray<float>[num_b,], marginal of the transport matrix for samples from geom_yy

Return type

LRCGeometry

Returns

Low-rank geometry whose cost matrix is marginal_dep_term (a sum of two rank-one matrices).
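Because the marginal-dependent term is a sum of two rank-one matrices, it can be stored as two thin factors instead of a dense [num_a, num_b] array. The sketch below assumes an LRCGeometry represents its cost matrix as cost_1 @ cost_2.T and builds such factors in plain NumPy; the helper marginal_dep_factors is hypothetical, does not call ott, and again assumes element-wise lin1 / lin2 (np.square by default).

```python
import numpy as np


def marginal_dep_factors(cost_xx, cost_yy, p, q, lin1=np.square, lin2=np.square):
    """Hypothetical helper: rank-2 factors of the marginal-dependent term.

    Returns (cost_1, cost_2) such that cost_1 @ cost_2.T equals
    lin1(C_xx) p 1^T + (lin2(C_yy) q 1^T)^T, without forming it densely.
    """
    u = (lin1(cost_xx) @ p)[:, None]  # [num_a, 1]
    v = (lin2(cost_yy) @ q)[:, None]  # [num_b, 1]
    # [u, 1] @ [1, v]^T = u 1^T + 1 v^T, i.e. entry (i, j) is u_i + v_j.
    cost_1 = np.concatenate([u, np.ones_like(u)], axis=1)  # [num_a, 2]
    cost_2 = np.concatenate([np.ones_like(v), v], axis=1)  # [num_b, 2]
    return cost_1, cost_2


cost_xx = np.array([[0.0, 1.0], [1.0, 0.0]])
cost_yy = np.array([[0.0, 2.0, 1.0], [2.0, 0.0, 1.0], [1.0, 1.0, 0.0]])
p = np.array([0.5, 0.5])
q = np.full(3, 1.0 / 3.0)

cost_1, cost_2 = marginal_dep_factors(cost_xx, cost_yy, p, q)
dense = cost_1 @ cost_2.T  # materialised here only to inspect the result
print(cost_1.shape, cost_2.shape, dense.shape)  # (2, 2) (3, 2) (2, 3)
```

Appending a column of ones to each vector is what turns the two broadcast terms into one matrix product, so storage grows with num_a + num_b rather than num_a * num_b.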