ott.neural.methods.monge_gap.monge_gap
- ott.neural.methods.monge_gap.monge_gap(map_fn, reference_points, cost_fn=None, epsilon=None, relative_epsilon=None, scale_cost=1.0, return_output=False, **kwargs)
Monge gap regularizer [Uscidda and Cuturi, 2023].
For a cost function \(c\) and empirical reference measure \(\hat{\rho}_n=\frac{1}{n}\sum_{i=1}^n \delta_{x_i}\), the (entropic) Monge gap of a map function \(T:\mathbb{R}^d\rightarrow\mathbb{R}^d\) is defined as:
\[\mathcal{M}^c_{\hat{\rho}_n, \varepsilon} (T) = \frac{1}{n} \sum_{i=1}^n c(x_i, T(x_i)) - W_{c, \varepsilon}(\hat{\rho}_n, T \sharp \hat{\rho}_n)\]
See [Uscidda and Cuturi, 2023] Eq. (8). This function is a thin wrapper that calls monge_gap_from_samples().
- Parameters:
  - map_fn (Callable[[Array], Array]) – Callable corresponding to the map \(T\) in the definition above. The callable should be vectorized (e.g. using vmap()), i.e., able to process a batch of vectors of size d: map_fn applied to an [n, d] array returns an array of the same shape.
  - reference_points (Array) – Array of [n, d] points defining \(\hat\rho_n\).
  - cost_fn (Optional[CostFn]) – Cost function \(c\) in the definition above. If None, the squared Euclidean cost is used.
  - epsilon (Optional[float]) – Entropic regularization parameter \(\varepsilon\). See PointCloud.
  - relative_epsilon (Optional[Literal['mean', 'std']]) – When None, epsilon specifies the value of the entropic regularization parameter directly. Otherwise, epsilon refers to a fraction of a statistic of the cost matrix (for 'mean', the mean_cost_matrix), computed adaptively from the source points (reference_points) and the target points (their images under map_fn).
  - scale_cost (Union[float, Literal['mean', 'max_cost', 'median']]) – Option to rescale the cost matrix. Implemented scalings are 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that cost_matrix /= scale_cost.
  - return_output (bool) – Whether to also return the SinkhornOutput.
  - kwargs (Any) – Keyword arguments used to instantiate the Sinkhorn solver that computes the regularized OT cost.
- Return type:
  Union[float, Tuple[float, SinkhornOutput]]
- Returns:
  The Monge gap value and, if return_output is True, also the SinkhornOutput.
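For intuition, the definition of \(\mathcal{M}^c_{\hat{\rho}_n, \varepsilon}\) above can be computed directly from its two terms. The following is a minimal NumPy sketch, not the library's implementation: it assumes the squared Euclidean cost, uniform weights, and that \(W_{c,\varepsilon}\) is the KL-regularized cost \(\min_P \langle P, C\rangle + \varepsilon\,\mathrm{KL}(P \,\|\, a \otimes b)\), estimated with a plain log-domain Sinkhorn loop. The helper names (sq_euclidean, entropic_ot_cost, monge_gap_sketch) and the defaults for epsilon and the iteration count are hypothetical.

```python
import numpy as np


def _logsumexp(z: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(z))) along `axis`."""
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def sq_euclidean(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Pairwise cost matrix C[i, j] = ||x_i - y_j||^2 (assumed cost c)."""
    diff = x[:, None, :] - y[None, :, :]
    return np.sum(diff**2, axis=-1)


def entropic_ot_cost(cost: np.ndarray, epsilon: float, num_iters: int = 1000) -> float:
    """KL-regularized OT cost between two uniform empirical measures (sketch).

    Runs log-domain Sinkhorn and returns the dual value <f, a> + <g, b>.
    After each g-update the plan has total mass 1, so this value is a lower
    bound on W_{c, eps} that increases monotonically with the iterations.
    """
    n, m = cost.shape
    log_a = np.full(n, -np.log(n))  # uniform weights on the source points
    log_b = np.full(m, -np.log(m))  # uniform weights on the mapped points
    f = np.zeros(n)
    g = np.zeros(m)
    for _ in range(num_iters):
        # f-update: row marginals of the plan become a.
        f = -epsilon * _logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        # g-update: column marginals of the plan become b.
        g = -epsilon * _logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
    return float(np.mean(f) + np.mean(g))


def monge_gap_sketch(map_fn, reference_points, epsilon: float = 1.0, num_iters: int = 1000) -> float:
    """(1/n) sum_i c(x_i, T(x_i)) - W_{c, eps}(rho_n, T # rho_n), hypothetical helper."""
    x = np.asarray(reference_points, dtype=float)
    y = map_fn(x)  # vectorized map: [n, d] -> [n, d]
    cost = sq_euclidean(x, y)  # C[i, j] = c(x_i, T(x_j))
    displacement = float(np.mean(np.diag(cost)))  # first term of the definition
    return displacement - entropic_ot_cost(cost, epsilon, num_iters)


x = np.arange(4.0).reshape(4, 1)  # four points 0, 1, 2, 3 on the real line
gap_identity = monge_gap_sketch(lambda z: z, x)
gap_reflection = monge_gap_sketch(lambda z: -z, x)
print(gap_identity, gap_reflection)
```

The identity map has zero displacement cost, so its gap is negative and bounded below by \(-\varepsilon \log n\), the entropic penalty of the diagonal coupling. The reflection \(x \mapsto -x\) is not optimal for the squared Euclidean cost: its displacement cost is 14 while the best matching of these four points costs 9, so its gap is at least \(5 - \varepsilon \log n\) and tends to 5 as \(\varepsilon \to 0\).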
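The scale_cost and relative_epsilon options can be illustrated on a plain cost matrix. This NumPy sketch follows the parameter descriptions above and is not OTT's code: the function names are hypothetical, the 'std' branch (epsilon as a fraction of the standard deviation of the cost entries) is an assumption based only on the type annotation, and the order in which the library combines rescaling and relative epsilon is not modeled.

```python
import numpy as np


def rescale_cost(cost: np.ndarray, scale_cost="mean") -> np.ndarray:
    """Hypothetical helper: cost_matrix /= scale_cost, by a float or a statistic."""
    if isinstance(scale_cost, (int, float)):
        factor = float(scale_cost)
    elif scale_cost == "mean":
        factor = float(np.mean(cost))
    elif scale_cost == "median":
        factor = float(np.median(cost))
    elif scale_cost == "max_cost":
        factor = float(np.max(cost))
    else:
        raise ValueError(f"unknown scale_cost: {scale_cost!r}")
    return cost / factor


def resolve_epsilon(cost: np.ndarray, epsilon: float, relative_epsilon=None) -> float:
    """Hypothetical helper: absolute epsilon, or a fraction of a cost statistic."""
    if relative_epsilon is None:
        return epsilon  # epsilon is used as-is
    if relative_epsilon == "mean":
        return epsilon * float(np.mean(cost))
    if relative_epsilon == "std":  # assumed semantics for 'std'
        return epsilon * float(np.std(cost))
    raise ValueError(f"unknown relative_epsilon: {relative_epsilon!r}")


cost = np.array([[0.0, 2.0], [2.0, 4.0]])
print(rescale_cost(cost, "max_cost"))
print(resolve_epsilon(cost, 0.05, relative_epsilon="mean"))
```

A relative epsilon keeps the amount of smoothing comparable across point clouds of different scales, since it grows and shrinks with the cost matrix itself.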