ott.initializers.nn.initializers.MetaInitializer
class ott.initializers.nn.initializers.MetaInitializer(geom, meta_model=None, opt=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), rng=None, state=None)
Meta OT Initializer with a fixed geometry [Amos et al., 2022].
This initializer consists of a predictive model that outputs the \(f\) duals to solve the entropy-regularized OT problem given input probability weights a and b, and a given (assumed to be fixed) geometry geom.

The model's parameters are learned using a training set of OT instances (multiple pairs of probability weights) that assume the same geometry geom is used throughout, both for training and evaluation. The meta model defaults to the MLP in MetaMLP, with batched problem instances passed into update().

- Parameters:
  - geom (Geometry) – The fixed geometry of the problem instances.
  - meta_model (Optional[Module]) – The model used to predict the potential \(f\) from the measures.
  - opt (Optional[GradientTransformation]) – The optimizer used to update the parameters. If None, uses optax.adam() with a \(0.001\) learning rate.
  - rng (Optional[PRNGKeyArray]) – The PRNG key used for initializing the model.
  - state (Optional[TrainState]) – The training state of the model to start from.
Examples

The following code shows a simple example of using update to train the model, where a and b are the weights of the measures and geom is the fixed geometry.

```python
meta_initializer = init_lib.MetaInitializer(geom)

while training():
    a, b = sample_batch()
    loss, init_f, meta_initializer.state = meta_initializer.update(
        meta_initializer.state, a=a, b=b
    )
Methods

- init_dual_a(ot_prob, lse_mode[, rng]) – Initialize Sinkhorn potential/scaling f_u.
- init_dual_b(ot_prob, lse_mode[, rng]) – Initialize Sinkhorn potential/scaling g_v.
- update(state, a, b) – Update the meta model with the dual objective.