ott.neural.methods.monge_gap.MongeGapEstimator

class ott.neural.methods.monge_gap.MongeGapEstimator(dim_data, model, optimizer=None, fitting_loss=None, regularizer=None, regularizer_strength=1.0, num_train_iters=10000, logging=False, valid_freq=500, rng=None)

Mapping estimator between probability measures.

It estimates a map $T$ by minimizing the loss:

$$\min_{\theta}\; \Delta(T_\theta \sharp \mu, \nu) + \lambda\, R(T_\theta \sharp \rho, \rho)$$

where $\Delta$ is a fitting loss, $R$ is a regularizer, and $\lambda$ (regularizer_strength) weights the two terms. $\Delta$ enforces the marginal constraint, i.e. that $T$ transports $\mu$ to $\nu$, while $R$ imposes an inductive bias on the learned map. Here the regularizer is a function that computes a score between two sets of points, $T_\theta \sharp \rho$ and $\rho$, where $\rho$ is a reference measure.
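To make the structure of the objective concrete, here is a minimal pure-Python sketch that evaluates $\Delta(T \sharp \mu, \nu) + \lambda R(T \sharp \rho, \rho)$ on one batch. It is not the library's implementation: the helper names (monge_gap_objective, mean_gap, mean_displacement) are hypothetical, the two losses are deliberately simple toys, and it assumes that the source batch also plays the role of the reference measure $\rho$. Both callables follow the documented convention of returning a value together with a log object.

```python
from typing import Any, Callable, Optional, Sequence, Tuple

Point = Tuple[float, ...]
# Hypothetical loss convention mirroring the docs: (points_a, points_b) -> (value, log).
LossFn = Callable[[Sequence[Point], Sequence[Point]], Tuple[float, Optional[Any]]]


def monge_gap_objective(
    T: Callable[[Point], Point],
    source: Sequence[Point],
    target: Sequence[Point],
    fitting_loss: LossFn,
    regularizer: LossFn,
    regularizer_strength: float = 1.0,
) -> float:
    """Evaluate Delta(T#mu, nu) + lambda * R(T#rho, rho) on one batch (sketch)."""
    mapped = [T(x) for x in source]        # samples of T # mu
    fit, _ = fitting_loss(mapped, target)  # Delta(T # mu, nu)
    # Assumption: rho = mu, so the regularizer compares mapped points with the source.
    reg, _ = regularizer(mapped, source)   # R(T # rho, rho)
    return fit + regularizer_strength * reg


def mean_gap(a: Sequence[Point], b: Sequence[Point]) -> Tuple[float, None]:
    """Toy fitting loss (hypothetical): squared distance between sample means."""
    dim = len(a[0])
    ma = [sum(p[k] for p in a) / len(a) for k in range(dim)]
    mb = [sum(p[k] for p in b) / len(b) for k in range(dim)]
    return sum((u - v) ** 2 for u, v in zip(ma, mb)), None


def mean_displacement(mapped: Sequence[Point], points: Sequence[Point]) -> Tuple[float, None]:
    """Toy regularizer (hypothetical): average squared displacement |T(x) - x|^2."""
    total = sum(sum((u - v) ** 2 for u, v in zip(p, q)) for p, q in zip(mapped, points))
    return total / len(points), None


src = [(0.0, 0.0), (1.0, 0.0), (0.0, 1.0)]
tgt = [(2.0, 0.0), (3.0, 0.0), (2.0, 1.0)]
shift = lambda x: (x[0] + 2.0, x[1])  # maps src exactly onto tgt
print(monge_gap_objective(shift, src, tgt, mean_gap, mean_displacement, 0.5))  # 2.0
```

The shift fits the target exactly, so the whole value comes from the regularizer term, scaled by $\lambda = 0.5$.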

For instance, $\Delta$ can be sinkhorn_divergence() and $R$ can be monge_gap_from_samples() for a given cost function $c$. In that case, the estimator learns a $c$-OT map, i.e. a map $T$ that is optimal for the Monge problem induced by $c$.
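The Monge gap of $T$ on samples $x_1, \dots, x_n$ is the average displacement cost $\frac{1}{n}\sum_i c(x_i, T(x_i))$ minus the optimal transport cost between the samples and their images; it is zero exactly when $T$ moves the points optimally. The sketch below is an assumption-laden stand-in, not monge_gap_from_samples() itself: rather than the entropic solver used in OTT, it swaps in exact OT by brute-force search over permutations, which is valid only for tiny, equal-size, uniformly weighted samples. The names monge_gap_exact and sq_euclidean are hypothetical.

```python
import itertools
from typing import Callable, Sequence, Tuple

Point = Tuple[float, ...]


def sq_euclidean(x: Point, y: Point) -> float:
    """Cost c(x, y) = ||x - y||^2."""
    return sum((a - b) ** 2 for a, b in zip(x, y))


def monge_gap_exact(
    source: Sequence[Point],
    mapped: Sequence[Point],
    cost: Callable[[Point, Point], float] = sq_euclidean,
) -> float:
    """Empirical Monge gap with exact OT (sketch; tiny n only).

    With equal sizes and uniform weights, OT reduces to an assignment problem,
    solved here by enumerating all permutations.
    """
    n = len(source)
    assert n == len(mapped), "samples must have the same size"
    displacement = sum(cost(x, y) for x, y in zip(source, mapped)) / n
    ot_cost = min(
        sum(cost(source[i], mapped[j]) for i, j in enumerate(perm)) / n
        for perm in itertools.permutations(range(n))
    )
    # The identity permutation is among the candidates, so the gap is >= 0.
    return displacement - ot_cost


src = [(0.0,), (1.0,), (2.0,)]
print(monge_gap_exact(src, [(0.0,), (2.0,), (4.0,)]))  # 0.0: monotone 1-D map is optimal
print(monge_gap_exact(src, [(4.0,), (2.0,), (0.0,)]))  # about 5.333 (= 16/3)
```

For the squared cost in one dimension, the increasing map $x \mapsto 2x$ has zero gap, while the decreasing map $x \mapsto 4 - 2x$ reaches the same image points through a suboptimal assignment and is penalized.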

Parameters:
• dim_data (int) – dimensionality of the input data, required to initialize the network.

• model (BasePotential) – network architecture for the map $T$.

• optimizer (optional) – optimizer used to train the map $T$.

• fitting_loss (optional) – function that takes two families of points and returns the fitting loss $\Delta$ between them, along with any log object.

• regularizer (optional) – function that takes two families of points, assumed here to be of the same size, and returns a score, along with any log object.

• regularizer_strength – strength $\lambda$ of the regularizer (default 1.0).

• num_train_iters (int) – total number of training iterations.

• logging (bool) – whether to return training logs.

• valid_freq (int) – frequency, in training iterations, at which training and validation are logged.

• rng (optional) – random key used to seed network initialization.
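The fitting_loss and regularizer parameters are plain callables that return a value and a log object. As a minimal sketch of that convention, the hypothetical function below computes the energy distance between two point clouds and returns its components as the log. It is a stand-in chosen only because it fits in a few lines of pure Python; the docs suggest sinkhorn_divergence() for this role, and nothing here reflects OTT's internal types.

```python
import math
from typing import Any, Dict, Sequence, Tuple

Point = Tuple[float, ...]


def _mean_dist(a: Sequence[Point], b: Sequence[Point]) -> float:
    """Average Euclidean distance over all pairs (x, y) with x in a, y in b."""
    return sum(math.dist(x, y) for x in a for y in b) / (len(a) * len(b))


def energy_fitting_loss(
    mapped: Sequence[Point], target: Sequence[Point]
) -> Tuple[float, Dict[str, Any]]:
    """Energy distance 2E|X-Y| - E|X-X'| - E|Y-Y'| plus a log dict (hypothetical)."""
    cross = _mean_dist(mapped, target)
    within_a = _mean_dist(mapped, mapped)
    within_b = _mean_dist(target, target)
    value = 2.0 * cross - within_a - within_b
    return value, {"cross": cross, "within_a": within_a, "within_b": within_b}


value, log = energy_fitting_loss([(0.0,)], [(3.0,)])
print(value, log["cross"])  # 6.0 3.0
```

The energy distance is zero when the two clouds coincide, which is the property a fitting loss needs in order to enforce the marginal constraint.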

Methods

• setup(dim_data, neural_net, optimizer) – Set up all components required to train the network.

• train_map_estimator(trainloader_source, ...) – Training loop.
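The interplay of num_train_iters, valid_freq and logging in a training loop can be sketched as follows. Everything here is a hypothetical simplification rather than the body of train_map_estimator: the map is a 1-D shift $T_b(x) = x + b$, the loss is a toy mean-matching fitting loss with an analytic gradient (no regularizer term), plain gradient descent replaces the configured optimizer, and the validation pass is omitted.

```python
import itertools
from typing import Callable, Dict, Iterator, List, Sequence, Tuple

LossAndGrad = Callable[[float, Sequence[float], Sequence[float]], Tuple[float, float]]


def train_loop_sketch(
    params: float,
    loss_and_grad: LossAndGrad,
    source_batches: Iterator[Sequence[float]],
    target_batches: Iterator[Sequence[float]],
    num_train_iters: int = 1000,
    valid_freq: int = 100,
    learning_rate: float = 0.1,
    logging: bool = False,
) -> Tuple[float, Dict[str, List[float]]]:
    """Hypothetical loop: one gradient step per iteration, log every valid_freq steps."""
    logs: Dict[str, List[float]] = {"train_loss": []}
    for step in range(num_train_iters):
        src, tgt = next(source_batches), next(target_batches)
        loss, grad = loss_and_grad(params, src, tgt)
        params = params - learning_rate * grad  # plain gradient descent
        if logging and step % valid_freq == 0:
            logs["train_loss"].append(loss)
    return params, logs


def shift_loss_and_grad(b: float, src: Sequence[float], tgt: Sequence[float]) -> Tuple[float, float]:
    """Toy loss for T_b(x) = x + b: (mean(T_b(src)) - mean(tgt))^2 and its derivative in b."""
    diff = sum(src) / len(src) + b - sum(tgt) / len(tgt)
    return diff ** 2, 2.0 * diff


b, logs = train_loop_sketch(
    0.0, shift_loss_and_grad,
    itertools.repeat([0.0, 1.0, 2.0]), itertools.repeat([3.0, 4.0, 5.0]),
    num_train_iters=200, valid_freq=50, logging=True,
)
print(round(b, 6), len(logs["train_loss"]))  # 3.0 4
```

With 200 iterations and valid_freq=50, the loss is recorded at steps 0, 50, 100 and 150; with logging=False the log stays empty.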

Attributes

• fitting_loss – Fitting loss used to fit the marginal constraint.

• regularizer – Regularizer added to the fitting loss.