ott.initializers.nn.initializers.MetaInitializer.update

MetaInitializer.update(state, a, b)

Update the meta model with the dual objective.

The goal is for the model's prediction to match the optimal dual potential, i.e., \(\hat f_\theta \approx f^\star\). Because \(f^\star\) is an optimum of the entropic dual objective, \(\hat f_\theta\) can be trained to optimize that same objective directly, so no precomputed \(f^\star\) is needed as a training target. The overall learning setup can thus be written as:

\[\min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; J(\hat f_\theta(a, b); \alpha, \beta),\]

where \(a, b\) are the probability weights of the measures \(\alpha, \beta\), \(\mathcal{D}\) is a meta distribution of optimal transport problems,

\[-J(f; \alpha, \beta) := \langle f, a\rangle + \langle g, b \rangle - \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\} \right\rangle\]

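is the entropic dual objective, \(g\) is the dual potential on the atoms of \(\beta\) that accompanies \(f\), and \(K_{i,j} := \exp(-C_{i,j}/\varepsilon)\) is the Gibbs kernel of the cost matrix \(C\).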
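As a minimal sketch in plain NumPy (not the OTT-JAX API), the objective above can be evaluated for a single problem. The point clouds, the cost matrix `C`, and the helper `g_from_f` are assumptions for illustration: the text does not say how \(g\) is obtained, so the sketch uses the closed-form \(g\) that maximizes the dual for a fixed \(f\).

```python
import numpy as np

def gibbs_kernel(C, eps):
    """Gibbs kernel K_{ij} = exp(-C_{ij} / eps)."""
    return np.exp(-C / eps)

def neg_dual_objective(f, g, a, b, C, eps):
    """J(f; alpha, beta): the negated entropic dual objective.

    Mirrors the formula literally; exp(f / eps) can overflow for small eps,
    where a log-domain (log-sum-exp) evaluation would be preferable.
    """
    K = gibbs_kernel(C, eps)
    dual = f @ a + g @ b - eps * (np.exp(f / eps) @ K @ np.exp(g / eps))
    return -dual

def g_from_f(f, b, C, eps):
    """Hypothetical helper (assumption): the g maximizing the dual for fixed f."""
    # Zeroing the gradient w.r.t. g_j gives
    # g_j = eps * log b_j - eps * log sum_i exp((f_i - C_ij) / eps).
    return eps * np.log(b) - eps * np.log(np.exp(f / eps) @ gibbs_kernel(C, eps))

# Small random problem: squared-Euclidean cost between two point clouds.
rng = np.random.default_rng(0)
x, y = rng.normal(size=(4, 2)), rng.normal(size=(5, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)
a, b = np.full(4, 1 / 4), np.full(5, 1 / 5)
eps = 1.0
f = rng.normal(size=4)
g = g_from_f(f, b, C, eps)
loss = neg_dual_objective(f, g, a, b, C, eps)
```

With this choice of \(g\), the column sums of the induced coupling equal \(b\), so the exponential term collapses to \(\varepsilon\) whenever \(b\) sums to one.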

Parameters:
  • state (TrainState) – Optimizer state of the meta model.

  • a (Array) – Probabilities of the \(\alpha\) measure’s atoms.

  • b (Array) – Probabilities of the \(\beta\) measure’s atoms.

Return type:

Tuple[Array, Array, TrainState]

Returns:

The training loss, the predicted dual potential \(f\), and the updated state.
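The full update step can be sketched in the same spirit. This is a hypothetical stand-in, not the library's implementation: a linear model \(f = W[a; b] + c\) replaces the meta model, a plain parameter dictionary with a manual SGD step replaces the `TrainState` and its optimizer, the cost matrix `C` is assumed fixed across problems, and random Dirichlet weights play the role of \(\mathcal{D}\). If \(g\) is chosen to maximize the dual for the current \(f\), the gradient of \(J\) with respect to \(f\) reduces to \(P\mathbf{1} - a\), where \(P_{ij} = \exp((f_i + g_j - C_{ij})/\varepsilon)\).

```python
import numpy as np

def g_from_f(f, b, C, eps):
    # Assumed choice of g: closed-form maximizer of the dual for fixed f.
    return eps * np.log(b) - eps * np.log(np.exp(f / eps) @ np.exp(-C / eps))

def meta_update(params, a, b, C, eps, lr=0.1):
    """Hypothetical sketch: one SGD step on J(f_theta(a, b); alpha, beta).

    Returns (loss, f, new_params), where loss and f are computed before the step.
    """
    W, c = params["W"], params["c"]
    z = np.concatenate([a, b])                        # model input: both weight vectors
    f = W @ z + c                                     # predicted potential f_theta(a, b)
    g = g_from_f(f, b, C, eps)
    P = np.exp((f[:, None] + g[None, :] - C) / eps)   # coupling induced by (f, g)
    loss = -(f @ a + g @ b - eps * P.sum())           # J = negated dual objective
    # g maximizes the dual for this f, so dJ/df = P 1 - a (envelope argument).
    grad_f = P.sum(axis=1) - a
    new_params = {"W": W - lr * np.outer(grad_f, z), "c": c - lr * grad_f}
    return loss, f, new_params

rng = np.random.default_rng(0)
n, m, eps = 4, 5, 0.5
x, y = rng.normal(size=(n, 2)), rng.normal(size=(m, 2))
C = ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)  # fixed cost (assumption)
params = {"W": np.zeros((n, n + m)), "c": np.zeros(n)}
losses = []
for _ in range(100):
    # Hypothetical meta distribution D: random weights on the fixed supports.
    a, b = rng.dirichlet(np.ones(n)), rng.dirichlet(np.ones(m))
    loss, f, params = meta_update(params, a, b, C, eps)
    losses.append(loss)
```

The returned triple mirrors the (loss, \(f\), state) tuple documented above; in an autodiff framework the gradient would normally come from differentiating the loss rather than from the closed form used here.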