ott.neural.losses.monge_gap

ott.neural.losses.monge_gap(map_fn, reference_points, cost_fn=None, epsilon=None, relative_epsilon=None, scale_cost=1.0, return_output=False, **kwargs)

Monge gap regularizer [Uscidda and Cuturi, 2023].

For a cost function \(c\) and empirical reference measure \(\hat{\rho}_n=\frac{1}{n}\sum_{i=1}^n \delta_{x_i}\), the (entropic) Monge gap of a map function \(T:\mathbb{R}^d\rightarrow\mathbb{R}^d\) is defined as:

\[\mathcal{M}^c_{\hat{\rho}_n, \varepsilon} (T) = \frac{1}{n} \sum_{i=1}^n c(x_i, T(x_i)) - W_{c, \varepsilon}(\hat{\rho}_n, T \sharp \hat{\rho}_n)\]

See [Uscidda and Cuturi, 2023] Eq. (8). This function is a thin wrapper: it evaluates map_fn on reference_points and passes both arrays to monge_gap_from_samples().
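To make Eq. (8) concrete, here is a minimal NumPy sketch, not OTT's implementation. It assumes the squared Euclidean cost and uniform weights, and computes \(W_{c,\varepsilon}\) with a plain log-domain Sinkhorn loop for the KL-regularized problem \(\min_P \langle P, C\rangle + \varepsilon\,\mathrm{KL}(P \,|\, a b^\top)\); OTT's Sinkhorn solver and its regularized-cost convention may differ. The helper names sq_euclidean, logsumexp, entropic_ot_cost and monge_gap_sketch are hypothetical.

```python
import numpy as np


def sq_euclidean(x, y):
    # Assumed cost: pairwise matrix C[i, j] = ||x_i - y_j||^2.
    return np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)


def logsumexp(z, axis):
    # Numerically stable log(sum(exp(z))) along `axis`.
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def entropic_ot_cost(x, y, epsilon, num_iters=2000):
    # Log-domain Sinkhorn for min_P <P, C> + epsilon * KL(P | a b^T) with
    # uniform weights a, b. After each g-update the coupling has total mass 1,
    # so the dual objective reduces to <f, a> + <g, b>.
    n, m = len(x), len(y)
    log_a, log_b = np.full(n, -np.log(n)), np.full(m, -np.log(m))
    c = sq_euclidean(x, y)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(num_iters):
        f = -epsilon * logsumexp(log_b[None, :] + (g[None, :] - c) / epsilon, axis=1)
        g = -epsilon * logsumexp(log_a[:, None] + (f[:, None] - c) / epsilon, axis=0)
    return f.mean() + g.mean()


def monge_gap_sketch(map_fn, reference_points, epsilon=0.5):
    # Eq. (8): mean displacement cost minus W_{c, eps}(rho_n, T # rho_n).
    target = map_fn(reference_points)  # map_fn must handle an [n, d] batch
    displacement = np.sum((reference_points - target) ** 2, axis=-1).mean()
    return displacement - entropic_ot_cost(reference_points, target, epsilon)


base = np.array([[1.0, 0.0], [0.0, 2.0], [1.5, -0.5], [-0.5, 1.0]])
points = np.concatenate([base, -base])  # symmetric cloud: -points permutes points

gap_shift = monge_gap_sketch(lambda x: x + np.array([3.0, -1.0]), points)
gap_flip = monge_gap_sketch(lambda x: -x, points)
```

A translation is an optimal map for the squared Euclidean cost, so gap_shift stays close to zero; in this sketch's convention the entropic term can only push it slightly negative, by at most \(\varepsilon \log n\). The reflection \(x \mapsto -x\) leaves the symmetric cloud unchanged as a measure while moving every point, so gap_flip is large and positive.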

Parameters:
  • map_fn (Callable[[Array], Array]) – Callable corresponding to the map \(T\) in the definition above. The callable must be vectorized (e.g., using jax.vmap()), i.e., it must process a batch of vectors of size d at once: applied to an [n, d] array, map_fn returns an array of the same shape.

  • reference_points (Array) – Array of shape [n, d] holding the points \(x_1, \dots, x_n\) that define \(\hat{\rho}_n\) in the paper.

  • cost_fn (Optional[CostFn]) – An object of class CostFn.

  • epsilon (Optional[float]) – Entropic regularization parameter \(\varepsilon\). See PointCloud.

  • relative_epsilon (Optional[bool]) – When False, epsilon is used directly as the value of the entropic regularization parameter. When True, epsilon is interpreted as a fraction of the mean_cost_matrix, which is computed adaptively from the source points (reference_points) and the target points (map_fn(reference_points)).

  • scale_cost (Union[int, float, Literal['mean', 'max_cost', 'median']]) – Option to rescale the cost matrix. Implemented scalings are ‘median’, ‘mean’ and ‘max_cost’. Alternatively, a float factor can be given, in which case the cost is rescaled as cost_matrix /= scale_cost.

  • return_output (bool) – Whether to also return the SinkhornOutput.

  • kwargs (Any) – Keyword arguments used to instantiate the Sinkhorn solver that computes the regularized OT cost.
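The documented effect of scale_cost and relative_epsilon on a cost matrix can be illustrated with a small NumPy sketch. The helper scaled_cost_and_epsilon is hypothetical, and it assumes that the relative epsilon is taken with respect to the mean of the already rescaled cost matrix; that ordering is an assumption about OTT's internals, not documented behavior.

```python
import numpy as np


def scaled_cost_and_epsilon(cost_matrix, epsilon, relative_epsilon=False, scale_cost=1.0):
    # Hypothetical helper mirroring the documented semantics of `scale_cost`
    # and `relative_epsilon`; OTT's order of operations may differ.
    if isinstance(scale_cost, str):
        reducers = {"mean": np.mean, "median": np.median, "max_cost": np.max}
        scale = reducers[scale_cost](cost_matrix)
    else:
        scale = float(scale_cost)
    scaled = cost_matrix / scale  # cost_matrix /= scale_cost
    # Assumption: relative epsilon is a fraction of the mean *scaled* cost.
    eps = epsilon * np.mean(scaled) if relative_epsilon else epsilon
    return scaled, eps


C = np.array([[0.0, 2.0], [4.0, 6.0]])
scaled_max, eps_abs = scaled_cost_and_epsilon(C, 0.1, scale_cost="max_cost")
scaled_one, eps_rel = scaled_cost_and_epsilon(C, 0.1, relative_epsilon=True)
```

With scale_cost="max_cost" every entry is divided by the largest cost (6.0 here), while relative_epsilon=True with no rescaling turns epsilon=0.1 into 0.1 times the mean cost of 3.0.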

Return type:

Union[float, Tuple[float, SinkhornOutput]]

Returns:

The Monge gap value and, if return_output is True, the SinkhornOutput.