ott.neural.models.MetaInitializer

class ott.neural.models.MetaInitializer(geom, meta_model, opt=<optax GradientTransformation built with optax.chain>, rng=None, state=None)

Meta OT Initializer with a fixed geometry [Amos et al., 2022].

This initializer consists of a predictive model that, given input probability weights a and b and a fixed geometry geom, outputs the \(f\) duals of the corresponding entropy-regularized OT problem.
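A minimal NumPy sketch of what predicting only the \(f\) duals buys, assuming the log-domain convention in which the transport plan is \(P_{ij} = \exp((f_i + g_j - C_{ij})/\varepsilon)\). Once \(f\) is predicted, the matching \(g\) follows in closed form from a single Sinkhorn half-step, and the pair can be scored with the entropic dual objective. The random `f_pred` stands in for the meta model's output; `g_from_f`, `dual_objective`, and the toy squared-Euclidean cost are hypothetical illustrations, not OTT's API or its internal implementation.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    mx = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(x - mx), axis=axis))


def g_from_f(f, b, cost, eps):
    """Closed-form g maximizing the entropic dual for a fixed f.

    This is one log-domain Sinkhorn half-step:
    g_j = eps * log b_j - eps * logsumexp_i((f_i - C_ij) / eps).
    """
    return eps * np.log(b) - eps * logsumexp((f[:, None] - cost) / eps, axis=0)


def dual_objective(f, g, a, b, cost, eps):
    """Entropic dual <f, a> + <g, b> - eps * sum_ij exp((f_i + g_j - C_ij) / eps)."""
    plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
    return f @ a + g @ b - eps * plan.sum()


rng = np.random.default_rng(0)
n, m, eps = 5, 4, 0.1
x, y = rng.uniform(size=(n, 2)), rng.uniform(size=(m, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # the fixed geometry
a = np.full(n, 1.0 / n)
b = np.full(m, 1.0 / m)

f_pred = rng.normal(scale=0.1, size=n)  # stand-in for the meta model's output
g = g_from_f(f_pred, b, cost, eps)
plan = np.exp((f_pred[:, None] + g[None, :] - cost) / eps)
print("dual value:", dual_objective(f_pred, g, a, b, cost, eps))
print("max column-marginal error:", np.abs(plan.sum(axis=0) - b).max())
```

Because \(g\) is chosen optimally for the given \(f\), the column marginals of the plan match b exactly, and the quality of the initialization is determined by \(f\) alone.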

The model’s parameters are learned from a training set of OT instances (multiple pairs of probability weights a and b). The same geometry geom is assumed throughout, both for training and for evaluation.
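The training described above can be sketched in NumPy under strong simplifications: the meta model is a hypothetical linear map \(f = W\,[a; b] + c\) rather than a neural network, the optimizer is plain full-batch gradient ascent rather than an optax optimizer, and \(g\) is eliminated in closed form so the objective depends on \(f\) alone. With that elimination, the gradient of the dual with respect to \(f\) is the row-marginal violation \(a - P\mathbf{1}\). All instances share one cost matrix, mirroring the fixed-geometry assumption; every name here is illustrative.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    mx = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(x - mx), axis=axis))


def dual_and_grad(f, a, b, cost, eps):
    """Entropic dual D(f), with g eliminated in closed form, and its gradient in f."""
    g = eps * np.log(b) - eps * logsumexp((f[:, None] - cost) / eps, axis=0)
    plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
    value = f @ a + g @ b - eps * plan.sum()
    return value, a - plan.sum(axis=1)  # gradient = row-marginal violation


rng = np.random.default_rng(0)
n, m, eps = 6, 5, 0.1
lr = eps / 4  # conservative: D is (1/eps)-smooth in f and ||[a; b; 1]||^2 <= 3
x, y = rng.uniform(size=(n, 2)), rng.uniform(size=(m, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # one fixed geometry


def sample_weights(size):
    w = rng.uniform(0.5, 1.5, size=size)
    return w / w.sum()


train_set = [(sample_weights(n), sample_weights(m)) for _ in range(32)]

# Hypothetical toy meta model: f = W @ concat(a, b) + c (not OTT's meta_model).
W, c = np.zeros((n, n + m)), np.zeros(n)


def mean_dual(W, c):
    """Average dual objective of the predicted f over the training set."""
    return np.mean([dual_and_grad(W @ np.concatenate([a, b]) + c, a, b, cost, eps)[0]
                    for a, b in train_set])


initial = mean_dual(W, c)
for step in range(200):
    gW, gc = np.zeros_like(W), np.zeros_like(c)
    for a, b in train_set:
        feats = np.concatenate([a, b])
        _, grad_f = dual_and_grad(W @ feats + c, a, b, cost, eps)
        gW += np.outer(grad_f, feats) / len(train_set)  # chain rule through f = W feats + c
        gc += grad_f / len(train_set)
    # Gradient ascent on the dual, i.e. descent on loss = -dual.
    W += lr * gW
    c += lr * gc
final = mean_dual(W, c)
print(f"mean dual before: {initial:.4f}, after: {final:.4f}")
```

The step size eps / 4 is below the inverse smoothness constant of the averaged objective, so each full-batch step cannot decrease the mean dual; a practical initializer would use a richer model and an adaptive optimizer.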

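Parameters:

geom: the fixed geometry shared by all OT instances, for both training and evaluation.
meta_model: the predictive model that maps the weights a and b to the \(f\) duals.
opt: the optax gradient transformation used to update the model's parameters.
rng: random key used to initialize the model's parameters.
state: training state of the model to start from; update returns the new state.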

Examples

The following code shows a simple example of using update to train the model, where a and b are the weights of the measures and geom is the fixed geometry.

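from ott.neural import models

# geom: the fixed geometry; meta_model: a model (e.g. a flax.linen.Module) that
# maps (a, b) to the f duals. training() and sample_batch() are user-defined.
meta_initializer = models.MetaInitializer(geom, meta_model)
while training():
  a, b = sample_batch()
  # Returns the loss, the predicted f duals and the updated training state.
  loss, init_f, meta_initializer.state = meta_initializer.update(
    meta_initializer.state, a=a, b=b
  )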

Methods

init_dual_a(ot_prob, lse_mode[, rng])

Initialize Sinkhorn potential/scaling f_u.

init_dual_b(ot_prob, lse_mode[, rng])

Initialize Sinkhorn potential/scaling g_v.

update(state, a, b)

Update the meta model with the dual objective.
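To illustrate why a good init_dual_a matters, here is a small log-domain Sinkhorn loop in NumPy that starts from a given potential \(f\) and counts iterations until the marginals match. A zero start plays the role of a non-learned initialization, and a near-optimal \(f\) stands in for the output of a well-trained meta model. This is a sketch under those assumptions, not OTT's solver; with lse_mode=False the same information would be carried by a scaling, which under the usual convention is \(u = \exp(f/\varepsilon)\).

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) along `axis`."""
    mx = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(mx, axis=axis) + np.log(np.sum(np.exp(x - mx), axis=axis))


def sinkhorn_log(a, b, cost, eps, f_init, tol=1e-8, max_iter=100_000):
    """Log-domain Sinkhorn started from the potential f_init.

    Returns (f, g, number of iterations used).
    """
    f = np.array(f_init, dtype=float)
    for it in range(1, max_iter + 1):
        g = eps * np.log(b) - eps * logsumexp((f[:, None] - cost) / eps, axis=0)
        f = eps * np.log(a) - eps * logsumexp((g[None, :] - cost) / eps, axis=1)
        plan = np.exp((f[:, None] + g[None, :] - cost) / eps)
        # Row sums equal a after the f-update, so only the columns are checked.
        if np.abs(plan.sum(axis=0) - b).sum() < tol:
            return f, g, it
    return f, g, max_iter


rng = np.random.default_rng(0)
n, m, eps = 8, 7, 0.2
x, y = rng.uniform(size=(n, 2)), rng.uniform(size=(m, 2))
cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)
a = rng.uniform(0.5, 1.5, size=n)
a /= a.sum()
b = rng.uniform(0.5, 1.5, size=m)
b /= b.sum()

# Cold start from f = 0, a simple non-learned initialization.
f_star, _, cold_iters = sinkhorn_log(a, b, cost, eps, np.zeros(n))
# Warm start from a near-optimal f, standing in for a well-trained predictor.
_, _, warm_iters = sinkhorn_log(a, b, cost, eps, f_star)
print(f"cold start: {cold_iters} iterations, warm start: {warm_iters}")
```

A learned initializer aims to land between these two extremes: its predicted \(f\) is not exact, but it should need noticeably fewer iterations than a zero start on problems drawn from the training distribution.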