ott.neural.methods.monge_gap.monge_gap
- ott.neural.methods.monge_gap.monge_gap(map_fn, reference_points, cost_fn=None, epsilon=None, relative_epsilon=None, scale_cost=1.0, return_output=False, **kwargs)
Monge gap regularizer [Uscidda and Cuturi, 2023].
For a cost function \(c\) and empirical reference measure \(\hat{\rho}_n=\frac{1}{n}\sum_{i=1}^n \delta_{x_i}\), the (entropic) Monge gap of a map function \(T:\mathbb{R}^d\rightarrow\mathbb{R}^d\) is defined as:
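\[\mathcal{M}^c_{\hat{\rho}_n, \varepsilon} (T) = \frac{1}{n} \sum_{i=1}^n c(x_i, T(x_i)) - W_{c, \varepsilon}(\hat{\rho}_n, T \sharp \hat{\rho}_n)\]
where \(T \sharp \hat{\rho}_n = \frac{1}{n}\sum_{i=1}^n \delta_{T(x_i)}\) is the push-forward of \(\hat{\rho}_n\) by \(T\), and \(W_{c, \varepsilon}\) is the entropy-regularized optimal transport cost. See [Uscidda and Cuturi, 2023] Eq. (8). This function is a thin wrapper that calls monge_gap_from_samples().
- Parameters: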
  - map_fn (Callable[[Array], Array]) – Callable corresponding to the map \(T\) in the definition above. The callable should be vectorized (e.g. using vmap()), i.e., able to process a batch of vectors of size d: map_fn applied to an [n, d] array returns an array of the same shape.
  - reference_points (Array) – Array of [n, d] points, \(\hat\rho_n\).
  - epsilon (Optional[float]) – Regularization parameter \(\varepsilon\). See PointCloud.
  - relative_epsilon (Optional[Literal['mean', 'std']]) – When None, the parameter epsilon specifies the value of the entropic regularization parameter. When set to 'mean' (or 'std'), epsilon refers to a fraction of the mean (or standard deviation) of the cost matrix, which is computed adaptively using the source and target points.
  - scale_cost (Union[float, Literal['mean', 'max_cost', 'median']]) – Option to rescale the cost matrix. Implemented scalings are 'median', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that cost_matrix /= scale_cost.
  - return_output (bool) – Whether to also return the SinkhornOutput.
  - kwargs (Any) – Keyword arguments used to instantiate the Sinkhorn solver that computes the regularized OT cost.
- Return type:
  Union[float, Tuple[float, SinkhornOutput]]
- Returns:
  The Monge gap value and, optionally, the SinkhornOutput.
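To make the definition concrete, here is a minimal NumPy sketch of the unregularized case (\(\varepsilon = 0\)) with squared Euclidean cost. It does not call OTT: the names sq_euclidean and exact_monge_gap are hypothetical, and a brute-force search over permutations (only feasible for a handful of points) stands in for an OT solver. Because both \(\hat{\rho}_n\) and \(T \sharp \hat{\rho}_n\) put mass \(1/n\) on n points, the OT term reduces to an assignment problem. Since the identity pairing is always feasible, this gap is never negative.

```python
import itertools

import numpy as np


def sq_euclidean(x, y):
    """Pairwise squared Euclidean costs c(x_i, y_j) = ||x_i - y_j||^2, shape [n, m]."""
    return np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def exact_monge_gap(map_fn, reference_points):
    """Hypothetical sketch: unregularized Monge gap by brute force over permutations.

    With uniform weights 1/n on both sides, W_c equals
    min over permutations sigma of (1/n) * sum_i c(x_i, T(x_sigma(i))).
    Enumerating all n! permutations is only practical for very small n.
    """
    x = np.asarray(reference_points, dtype=float)
    y = map_fn(x)  # vectorized map: applied to the whole [n, d] batch at once
    cost = sq_euclidean(x, y)
    n = x.shape[0]
    # First term: average cost of moving each x_i to its own image T(x_i).
    displacement = np.trace(cost) / n
    # Second term: best achievable average cost over all re-pairings.
    w_c = min(
        cost[np.arange(n), list(perm)].sum() / n
        for perm in itertools.permutations(range(n))
    )
    return displacement - w_c


x = np.array([[0.0], [1.0], [2.0]])
print(exact_monge_gap(lambda z: z + 5.0, x))  # translation is optimal here: gap 0
print(exact_monge_gap(lambda z: -z, x))       # reflection is not: gap 8/3
```

For the reflection \(T(x) = -x\), the displacement term is \((0 + 4 + 16)/3 = 20/3\), while the best pairing costs \(4\) on average, so the gap is \(8/3\).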
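The scale_cost options can be read as dividing the cost matrix by a factor, either a given float or a statistic of the matrix itself. The hypothetical helper rescale_cost below sketches that reading in NumPy. It is not OTT's implementation, and details such as which matrix the statistics are computed on are assumptions. Rescaling matters because an absolute epsilon enters the Sinkhorn kernel as \(\exp(-C/\varepsilon)\), so its effect depends on the scale of the costs.

```python
import numpy as np


def rescale_cost(cost_matrix, scale_cost=1.0):
    """Hypothetical sketch of the documented scale_cost options.

    A float divides the cost matrix by that factor (cost_matrix /= scale_cost);
    the strings pick a data-dependent factor from the matrix itself.
    """
    cost_matrix = np.asarray(cost_matrix, dtype=float)
    if isinstance(scale_cost, (int, float)):
        factor = float(scale_cost)
    elif scale_cost == "mean":
        factor = cost_matrix.mean()
    elif scale_cost == "max_cost":
        factor = cost_matrix.max()
    elif scale_cost == "median":
        factor = np.median(cost_matrix)
    else:
        raise ValueError(f"unknown scale_cost option: {scale_cost!r}")
    return cost_matrix / factor


C = np.array([[0.0, 1.0], [4.0, 9.0]])
print(rescale_cost(C, "max_cost"))  # largest entry becomes 1
print(rescale_cost(C, 2.0))         # every entry halved
```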