ott.neural.solvers.map_estimator.MapEstimator

class ott.neural.solvers.map_estimator.MapEstimator(dim_data, model, optimizer=None, fitting_loss=None, regularizer=None, regularizer_strength=1.0, num_train_iters=10000, logging=False, valid_freq=500, rng=None)

Mapping estimator between probability measures.

It estimates a map \(T\) by minimizing the loss:

\[\min_{\theta}\; \Delta(T_\theta \sharp \mu, \nu) + \lambda R(T_\theta \sharp \rho, \rho)\]

where \(\Delta\) is a fitting loss, \(R\) is a regularizer, and \(\lambda\) is the weight of the regularizer (regularizer_strength). \(\Delta\) enforces the marginal constraint, i.e. that \(T\) transports \(\mu\) to \(\nu\), while \(R\) imposes an inductive bias on the learned map. In this implementation, the regularizer is a function that computes a metric between two sets of points.

For instance, \(\Delta\) can be sinkhorn_divergence() and \(R\) can be monge_gap_from_samples() [Uscidda and Cuturi, 2023] for a given cost function \(c\). In that case, the estimator learns a \(c\)-OT map, i.e. a map \(T\) that is optimal for the Monge problem induced by \(c\).
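The role of the two terms can be illustrated with a small, self-contained sketch. It does not use the OTT API: it assumes 1-D data and a squared Euclidean cost, for which the optimal transport cost between equal-size samples is obtained by sorting. The helper names (w2_squared_1d, monge_gap_1d, objective, monotone_map, flipped_map) are hypothetical, the fitting loss is an exact 1-D Wasserstein distance standing in for the Sinkhorn divergence, the regularizer is an unregularized 1-D variant of the Monge gap, and \(\rho\) is taken equal to \(\mu\).

```python
import random


def w2_squared_1d(xs, ys):
    """Exact squared 2-Wasserstein distance between equal-size 1-D samples."""
    if len(xs) != len(ys):
        raise ValueError("samples must have the same size")
    # In 1-D with squared cost, the optimal coupling matches sorted samples.
    return sum((a - b) ** 2 for a, b in zip(sorted(xs), sorted(ys))) / len(xs)


def monge_gap_1d(xs, mapped):
    """Unregularized Monge gap: mean cost of moving x to T(x) minus the OT cost."""
    displacement = sum((x - t) ** 2 for x, t in zip(xs, mapped)) / len(xs)
    return displacement - w2_squared_1d(xs, mapped)


def objective(T, source, target, lam=1.0):
    """Delta(T#mu, nu) + lam * R(T#rho, rho), with rho = mu (an assumption)."""
    mapped = [T(x) for x in source]
    return w2_squared_1d(mapped, target) + lam * monge_gap_1d(source, mapped)


def monotone_map(x):
    # Increasing map: optimal for the squared cost in 1-D.
    return 2.0 * x + 1.0


def flipped_map(x):
    # Decreasing map: reaches the same target points, but is not optimal.
    return -2.0 * x + 1.0


rng = random.Random(0)
half = [rng.gauss(0.0, 1.0) for _ in range(128)]
source = half + [-x for x in half]          # symmetric source sample
target = [monotone_map(x) for x in source]  # target sample
rng.shuffle(target)                         # samples are unpaired

for name, T in [("monotone", monotone_map), ("flipped", flipped_map)]:
    mapped = [T(x) for x in source]
    print(name,
          "fitting:", w2_squared_1d(mapped, target),
          "monge gap:", monge_gap_1d(source, mapped))
```

Because the source sample is symmetric, both maps push it onto the same set of target points, so the fitting term alone cannot tell them apart. Only the regularizer separates them: it vanishes (up to rounding) for the increasing map and is strictly positive for the decreasing one.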

Parameters:

Methods

setup(dim_data, neural_net, optimizer)

Set up all components required to train the network.

train_map_estimator(trainloader_source, ...)

Training loop.

Attributes

fitting_loss

Fitting loss to fit the marginal constraint.

regularizer

Regularizer added to the fitting loss.