ott.initializers.nn.initializers.MetaInitializer

class ott.initializers.nn.initializers.MetaInitializer(geom, meta_model=None, opt=GradientTransformation(init=<function chain.<locals>.init_fn>, update=<function chain.<locals>.update_fn>), rng=Array([0, 0], dtype=uint32), state=None)

Meta OT initializer with a fixed geometry.

This initializer consists of a predictive model that outputs the dual potential $f$ of the entropy-regularized OT problem, given input probability weights a and b and a geometry geom that is assumed to be fixed.
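Predicting $f$ alone is enough to pin down a full dual pair, because the matching $g$ follows from $f$ by one log-domain Sinkhorn update. The minimal JAX sketch below illustrates this; the toy cost matrix, `epsilon`, and the helpers `g_from_f` and `plan` are hypothetical and not part of this class's API. With the plan $P_{ij} = \exp((f_i + g_j - C_{ij})/\varepsilon)$, requiring column sums $b$ gives $g_j = \varepsilon \log b_j - \varepsilon \log \sum_i \exp((f_i - C_{ij})/\varepsilon)$.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def g_from_f(f, b, cost, epsilon):
  """Hypothetical helper: log-domain Sinkhorn half-step giving g from f.

  Chooses g so that the plan induced by (f, g) has column sums exactly b.
  """
  return epsilon * jnp.log(b) - epsilon * logsumexp(
      (f[:, None] - cost) / epsilon, axis=0
  )


def plan(f, g, cost, epsilon):
  """Entropic transport plan P_ij = exp((f_i + g_j - C_ij) / epsilon)."""
  return jnp.exp((f[:, None] + g[None, :] - cost) / epsilon)


# Toy problem: 3 source points and 4 target points on a line.
x = jnp.array([0.0, 0.5, 1.0])
y = jnp.array([0.1, 0.4, 0.7, 0.9])
cost = (x[:, None] - y[None, :]) ** 2
a = jnp.array([0.2, 0.3, 0.5])
b = jnp.array([0.25, 0.25, 0.25, 0.25])
epsilon = 0.1

f = jnp.zeros_like(a)  # stand-in for a potential predicted by the meta model
g = g_from_f(f, b, cost, epsilon)
P = plan(f, g, cost, epsilon)
print(P.sum(axis=0))  # column sums match b
```

The row sums of `P` generally differ from `a`; a better predicted $f$ shrinks that gap, which is what lets a good prediction warm-start Sinkhorn.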

The model's parameters are learned from a training set of OT instances (multiple pairs of probability weights), all of which share the same geometry geom, both during training and at evaluation time. The meta model defaults to the MLP in MetaMLP and is trained by passing batches of problem instances to update().
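The training scheme described above can be illustrated with a minimal JAX sketch. It is not the class's update() method: the linear model `predict_f`, the toy cost matrix, the weight sampler, and the plain gradient-descent loop are hypothetical stand-ins for MetaMLP and the optax optimizer. It assumes the objective is the entropic dual evaluated at the predicted $f$ and the $g$ obtained from $f$ by one Sinkhorn half-step, and it minimizes the negative mean of that dual over a batch of weight pairs that all share one fixed cost matrix.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

# Hypothetical fixed geometry: squared distances between two 1-D grids.
x = jnp.linspace(0.0, 1.0, 5)
y = jnp.linspace(0.0, 1.0, 6)
cost = (x[:, None] - y[None, :]) ** 2
epsilon = 0.1


def predict_f(params, a, b):
  """Hypothetical linear meta model mapping the weights (a, b) to f."""
  return params["W"] @ jnp.concatenate([a, b]) + params["c"]


def dual_value(f, a, b):
  """Entropic dual at (f, g), with g from one log-domain Sinkhorn half-step.

  The usual -epsilon * sum(P) term is constant here (this g makes
  sum(P) = sum(b) = 1), so it is dropped.
  """
  g = epsilon * jnp.log(b) - epsilon * logsumexp(
      (f[:, None] - cost) / epsilon, axis=0
  )
  return jnp.dot(f, a) + jnp.dot(g, b)


def loss(params, a_batch, b_batch):
  """Negative mean dual value over a batch of OT instances."""
  values = jax.vmap(lambda a, b: dual_value(predict_f(params, a, b), a, b))(
      a_batch, b_batch
  )
  return -jnp.mean(values)


def random_weights(key, batch_size, n):
  """Batch of strictly positive probability vectors of length n."""
  w = jax.random.uniform(key, (batch_size, n), minval=0.1, maxval=1.0)
  return w / w.sum(axis=1, keepdims=True)


key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
a_batch = random_weights(key_a, 16, 5)
b_batch = random_weights(key_b, 16, 6)

params = {"W": jnp.zeros((5, 11)), "c": jnp.zeros(5)}
step_fn = jax.jit(jax.value_and_grad(loss))
learning_rate = 0.02
losses = []
for _ in range(100):
  value, grads = step_fn(params, a_batch, b_batch)
  # Plain gradient descent on every parameter array.
  params = jax.tree_util.tree_map(
      lambda p, d: p - learning_rate * d, params, grads
  )
  losses.append(float(value))
print(losses[0], losses[-1])  # the loss decreases as the model learns f
```

Because every instance reuses the same `cost`, the model only has to learn how $f$ depends on the weights, which is why the geometry must stay fixed between training and evaluation.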

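Parameters

geom: The fixed geometry shared by all problem instances, during both training and evaluation.
meta_model: The model that predicts the potential $f$ from the weights a and b. If None, defaults to MetaMLP.
opt: The optax gradient transformation used to update the meta model's parameters.
rng: The PRNG key used to initialize the meta model's parameters.
state: The training state of the meta model to start from.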

Examples

The following code shows a simple example of using update to train the model, where a and b are the weights of the measures and geom is the fixed geometry.

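from ott.initializers.nn import initializers as init_lib

meta_initializer = init_lib.MetaInitializer(geom)  # geom: the fixed geometry
while training():  # placeholder for a user-defined stopping criterion
  a, b = sample_batch()  # placeholder: returns a batch of weight pairs
  loss, init_f, meta_initializer.state = meta_initializer.update(
      meta_initializer.state, a=a, b=b
  )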
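The example leaves training() and sample_batch() undefined. One possible sketch of sample_batch() in JAX is shown below; the function, its arguments, and the batch-first layout (shapes (batch, n) and (batch, m) for weights over the n and m points of geom) are assumptions for illustration, not part of the OTT API. It takes an explicit PRNG key because JAX randomness is stateless.

```python
import jax
import jax.numpy as jnp


def sample_batch(key, batch_size, n, m):
  """Hypothetical stand-in for sample_batch(): random weight pairs.

  Each row of a (resp. b) is a probability vector over the n (resp. m)
  points of the fixed geometry, drawn from a flat Dirichlet distribution.
  """
  key_a, key_b = jax.random.split(key)
  a = jax.random.dirichlet(key_a, jnp.ones(n), shape=(batch_size,))
  b = jax.random.dirichlet(key_b, jnp.ones(m), shape=(batch_size,))
  return a, b


a, b = sample_batch(jax.random.PRNGKey(0), batch_size=8, n=10, m=12)
print(a.shape, b.shape)  # (8, 10) (8, 12)
```

In a real loop, split a fresh key on every iteration so that each batch contains new problem instances.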


Methods

init_dual_a(ot_prob, lse_mode)    Initialize Sinkhorn potential/scaling f_u.
init_dual_b(ot_prob, lse_mode)    Initialize Sinkhorn potential/scaling g_v.
update(state, a, b)               Update the meta model with the dual objective.
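The "potential/scaling" wording in init_dual_a and init_dual_b reflects the two Sinkhorn modes: with lse_mode enabled, Sinkhorn iterates on potentials $f$, and otherwise on scalings $u = \exp(f/\varepsilon)$. The sketch below shows only that relationship; the helper `to_sinkhorn_init` is hypothetical and is not how this class implements its methods.

```python
import jax.numpy as jnp


def to_sinkhorn_init(f, epsilon, lse_mode):
  """Hypothetical helper: express a potential f in the form Sinkhorn uses.

  Log-sum-exp mode works on potentials directly; kernel mode works on
  scalings u = exp(f / epsilon).
  """
  return f if lse_mode else jnp.exp(f / epsilon)


f = jnp.array([0.0, 0.05, -0.1])
epsilon = 0.05
u = to_sinkhorn_init(f, epsilon, lse_mode=False)
print(u)                     # ≈ exp([0, 1, -2])
print(epsilon * jnp.log(u))  # recovers f
```

A zero potential corresponds to a unit scaling, so the two modes describe the same starting point in different coordinates.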