ott.initializers.nn.initializers.MetaInitializer.update
- MetaInitializer.update(state, a, b)
Update the meta model with the dual objective.
The goal is for the model to match the optimal duals, i.e., \(\hat f_\theta \approx f^\star\). This can be done by training \(\hat f_\theta\) so that its predictions optimize the dual objective, which \(f^\star\) itself optimizes. The overall learning setup can thus be written as:
\[\min_\theta\; {\mathbb E}_{(\alpha,\beta)\sim{\mathcal{D}}}\; J(\hat f_\theta(a, b); \alpha, \beta, c),\]where \(a,b\) are the probabilities of the measures \(\alpha,\beta\), \(\mathcal{D}\) is a meta distribution of optimal transport problems,
\[-J(f; \alpha, \beta, c) := \langle f, a\rangle + \langle g, b \rangle - \varepsilon\left\langle \exp\{f/\varepsilon\}, K\exp\{g/\varepsilon\} \right\rangle\]is the entropic dual objective, \(g\) is the dual potential paired with \(f\), \(\varepsilon\) is the entropic regularization strength, and \(K_{i,j} := \exp\{-C_{i,j}/\varepsilon\}\) is the Gibbs kernel of the cost matrix \(C\).
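As a rough illustration of the loss described above, the following JAX sketch evaluates \(J\) for a given \(f\) on a tiny two-point problem and differentiates it. It is not the library's implementation. The helper `g_from_f` is hypothetical, and it assumes that \(g\) is recovered from \(f\) by a single Sinkhorn-style update, which the text above does not state. The kernel is taken to be the standard Gibbs kernel \(\exp\{-C/\varepsilon\}\), and all values are made up for the example.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def g_from_f(f, b, cost, eps):
    """Hypothetical helper (an assumption): the best g for a fixed f."""
    # g_j = eps * log b_j - eps * log sum_i exp((f_i - C_ij) / eps)
    return eps * jnp.log(b) - eps * logsumexp((f[:, None] - cost) / eps, axis=0)


def neg_dual_objective(f, a, b, cost, eps):
    """J(f; alpha, beta, c): the negative entropic dual, to be minimized."""
    g = g_from_f(f, b, cost, eps)
    # <exp(f/eps), K exp(g/eps)> with K = exp(-C/eps), evaluated in log space.
    log_coupling = (f[:, None] + g[None, :] - cost) / eps
    dual = f @ a + g @ b - eps * jnp.exp(logsumexp(log_coupling))
    return -dual


# Toy problem: two atoms per measure, made-up weights and costs.
a = jnp.array([0.5, 0.5])
b = jnp.array([0.25, 0.75])
cost = jnp.array([[0.0, 1.0], [1.0, 0.0]])
eps = 0.1

f = jnp.zeros(2)  # stand-in for a prediction of the meta model
loss, grad = jax.value_and_grad(neg_dual_objective)(f, a, b, cost, eps)

# One plain gradient step on f; a real meta model would update theta instead.
f_next = f - 0.01 * grad
loss_next = neg_dual_objective(f_next, a, b, cost, eps)
```

In this sketch the gradient is taken with respect to \(f\) directly. In the meta-learning setup, \(f\) would be the output of \(\hat f_\theta(a, b)\), and the same loss would be differentiated with respect to \(\theta\).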
- Parameters:
state (TrainState) – Optimizer state of the meta model.
a (Array) – Probabilities of the \(\alpha\) measure’s atoms.
b (Array) – Probabilities of the \(\beta\) measure’s atoms.
- Return type:
- Returns:
The training loss, \(f\), and updated state.