ott.initializers.nn.initializers.MetaInitializer#

class ott.initializers.nn.initializers.MetaInitializer(geom, meta_model=None, opt=GradientTransformationExtraArgs(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), rng=None, state=None)[source]#

Meta OT Initializer with a fixed geometry [Amos et al., 2022].

This initializer consists of a predictive model that outputs the \(f\) dual potential to solve the entropy-regularized OT problem, given input probability weights a and b and a fixed geometry geom.

The model’s parameters are learned from a training set of OT instances (multiple pairs of probability weights), all of which are assumed to share the same geometry geom for both training and evaluation. The meta-model defaults to the MLP in MetaMLP, and batched problem instances are passed into update().
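To make the training objective concrete, the sketch below (plain NumPy, not the ott API) shows the quantity such a model is trained on: given a predicted potential \(f\), the conjugate potential \(g\) is recovered with one softmin step and the entropy-regularized dual objective is evaluated. The function and variable names here are illustrative, not part of ott.

```python
import numpy as np

def logsumexp(x, axis):
    # Numerically stable log-sum-exp along `axis`.
    m = x.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))).squeeze(axis)

def dual_objective(f, a, b, C, eps):
    # Recover g from the predicted f with one softmin step, then
    # evaluate the entropy-regularized dual objective (illustrative sketch
    # of the loss the meta-model is trained against, not ott's code).
    g = eps * np.log(b) - eps * logsumexp((f[:, None] - C) / eps, axis=0)
    P = np.exp((f[:, None] + g[None, :] - C) / eps)  # implied transport plan
    return f @ a + g @ b - eps * P.sum(), g, P

rng = np.random.default_rng(0)
n, eps = 5, 0.5
x, y = rng.normal(size=(n, 2)), rng.normal(size=(n, 2))
C = ((x[:, None] - y[None, :]) ** 2).sum(-1)  # squared-Euclidean cost matrix
a = np.full(n, 1 / n)
b = np.full(n, 1 / n)
obj, g, P = dual_objective(np.zeros(n), a, b, C, eps)
# By construction of the softmin step, the column marginals of P match b.
```

Maximizing this objective over the model's parameters, averaged over sampled weight pairs, is what drives the meta-training loop.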

Parameters:
  • geom (Geometry) – The fixed geometry of the problem instances.

  • meta_model (Optional[Module]) – The model to predict the potential \(f\) from the measures.

  • opt (Optional[GradientTransformation]) – The optimizer used to update the parameters. If None, optax.adam() with a \(0.001\) learning rate is used.

  • rng (Optional[PRNGKeyArray]) – The PRNG key to use for initializing the model.

  • state (Optional[TrainState]) – The training state of the model to start from.

Examples

The following code shows a simple example of using update() to train the model, where a and b are the weights of the measures and geom is the fixed geometry.

meta_initializer = init_lib.MetaInitializer(geom)
while training():  # user-defined loop condition
  a, b = sample_batch()  # user-defined sampler of probability weight pairs
  loss, init_f, meta_initializer.state = meta_initializer.update(
    meta_initializer.state, a=a, b=b
  )

Methods

init_dual_a(ot_prob, lse_mode[, rng])

Initialize Sinkhorn potential/scaling f_u.

init_dual_b(ot_prob, lse_mode[, rng])

Initialize Sinkhorn potential/scaling g_v.

update(state, a, b)

Update the meta model with the dual objective.