ott.neural.methods.monge_gap.MongeGapEstimator
class ott.neural.methods.monge_gap.MongeGapEstimator(dim_data, model, optimizer=None, fitting_loss=None, regularizer=None, regularizer_strength=1.0, num_train_iters=10000, logging=False, valid_freq=500, rng=None)
Mapping estimator between probability measures.
It estimates a map \(T\) by minimizing the loss:
\[\min_{\theta}\; \Delta(T_\theta \sharp \mu, \nu) + \lambda R(T_\theta \sharp \rho, \rho)\]
where \(\Delta\) is a fitting loss, \(R\) is a regularizer and \(\lambda\) (regularizer_strength) weights the regularizer. \(\Delta\) is used to fit the marginal constraint, i.e. to transport \(\mu\) to \(\nu\) via \(T\), while \(R\) imposes an inductive bias on the learned map. Here, the regularizer is a function that computes a score between two sets of points.
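A minimal NumPy sketch of how the two terms might be combined on mini-batches, assuming (as the parameter descriptions state) that both callables take two arrays of points and return a `(value, log)` pair. The names `total_loss` and `map_fn`, and the toy losses in the usage part, are hypothetical; using the source batch as the samples of \(\rho\) is also an assumption of this sketch, not something the reference above specifies.

```python
from typing import Any, Callable, Dict, Optional, Tuple

import numpy as np

# Hypothetical alias mirroring the documented callable signature:
# (points, points) -> (scalar value, optional log object).
LossFn = Callable[[np.ndarray, np.ndarray], Tuple[float, Optional[Any]]]


def total_loss(
    map_fn: Callable[[np.ndarray], np.ndarray],
    source: np.ndarray,
    target: np.ndarray,
    fitting_loss: LossFn,
    regularizer: LossFn,
    regularizer_strength: float,
) -> Tuple[float, Dict[str, Any]]:
    """Mini-batch estimate of Delta(T#mu, nu) + lambda * R(T#rho, rho).

    Sketch only; not part of ott.
    """
    mapped = map_fn(source)  # samples of T#mu
    fit_val, fit_log = fitting_loss(mapped, target)
    # Assumption of this sketch: rho is represented by the source batch,
    # so the regularizer compares two point clouds of the same size.
    reg_val, reg_log = regularizer(source, mapped)
    loss = fit_val + regularizer_strength * reg_val
    return loss, {"fitting_loss": fit_log, "regularizer": reg_log}


def mean_matching(a: np.ndarray, b: np.ndarray) -> Tuple[float, None]:
    # Toy fitting loss: squared distance between the two sample means.
    return float(np.sum((a.mean(axis=0) - b.mean(axis=0)) ** 2)), None


def mean_sq_displacement(a: np.ndarray, b: np.ndarray) -> Tuple[float, None]:
    # Toy regularizer: average squared distance between paired points.
    return float(np.mean(np.sum((a - b) ** 2, axis=1))), None


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=(8, 2))
    loss, logs = total_loss(
        lambda z: z + 1.0, x, x + 1.0, mean_matching, mean_sq_displacement, 0.5
    )
    print(loss)  # approximately 1.0: fitting term 0, regularizer term 2.0
```

Keeping each term's log separate mirrors the documented "as well as any log object" return convention.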
For instance, \(\Delta\) can be sinkhorn_divergence() and \(R\) can be monge_gap_from_samples() [Uscidda and Cuturi, 2023] for a given cost function \(c\). In that case, the estimator learns a \(c\)-OT map, i.e. a map \(T\) that is optimal for the Monge problem induced by \(c\).
- Parameters:
  - dim_data (int) – input dimensionality of the data, required for network initialization.
  - model (BasePotential) – network architecture for the map \(T\).
  - optimizer (Union[Array, ndarray, bool, number, Iterable[ArrayTree], Mapping[Any, ArrayTree], None]) – optimizer for the map \(T\).
  - fitting_loss (Optional[Callable[[Array, Array], Tuple[float, Optional[Any]]]]) – function that outputs a fitting loss \(\Delta\) between two families of points, as well as any log object.
  - regularizer (Optional[Callable[[Array, Array], Tuple[float, Optional[Any]]]]) – function that outputs a score from two families of points, here assumed to be of the same size, as well as any log object.
  - regularizer_strength (Union[float, Sequence[float]]) – strength \(\lambda\) of the regularizer.
  - num_train_iters (int) – total number of training iterations.
  - logging (bool) – whether to return logs.
  - valid_freq (int) – frequency with which training and validation metrics are logged.
  - rng (Optional[Array]) – random key used to seed network initialization.
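As an illustration of the callable signature expected by fitting_loss, here is a hypothetical custom fitting loss: the (V-statistic) squared energy distance between two point clouds, returning `None` as its log object. It is written in NumPy for readability; to be differentiated inside a JAX training loop it would have to use `jax.numpy` and drop the `float` conversions.

```python
from typing import Any, Optional, Tuple

import numpy as np


def energy_distance(x: np.ndarray, y: np.ndarray) -> Tuple[float, Optional[Any]]:
    """Hypothetical fitting loss: 2E|X-Y| - E|X-X'| - E|Y-Y'| (V-statistic)."""

    def mean_pairwise_dist(a: np.ndarray, b: np.ndarray) -> float:
        # (n, 1, d) - (1, m, d) -> (n, m, d); Euclidean norm over features.
        diffs = a[:, None, :] - b[None, :, :]
        return float(np.mean(np.linalg.norm(diffs, axis=-1)))

    value = (
        2.0 * mean_pairwise_dist(x, y)
        - mean_pairwise_dist(x, x)
        - mean_pairwise_dist(y, y)
    )
    return value, None  # no log object in this sketch
```

The value is zero when both clouds coincide and grows as they drift apart, which is the behavior a fitting loss \(\Delta\) needs.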
Methods
setup(dim_data, neural_net, optimizer) – Set up all components required to train the network.
train_map_estimator(trainloader_source, ...) – Training loop.
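To make the roles of num_train_iters, valid_freq and logging concrete, here is a toy gradient-descent loop. It is not the library's train_map_estimator: the map is a hypothetical translation \(T_b(x) = x + b\), the fitting loss is a simple mean-matching stand-in for \(\Delta\), and the regularizer term is dropped because a translation has zero Monge gap under the squared-Euclidean cost.

```python
from typing import List, Tuple

import numpy as np


def train_translation_map(
    source: np.ndarray,
    target: np.ndarray,
    num_train_iters: int = 200,
    valid_freq: int = 50,
    learning_rate: float = 0.1,
    logging: bool = True,
) -> Tuple[np.ndarray, List[Tuple[int, float]]]:
    """Hypothetical toy loop: fit T_b(x) = x + b on ||mean(T_b(x)) - mean(y)||^2."""
    b = np.zeros(source.shape[1])
    logs: List[Tuple[int, float]] = []
    for step in range(1, num_train_iters + 1):
        residual = source.mean(axis=0) + b - target.mean(axis=0)
        loss = float(np.sum(residual ** 2))
        b = b - learning_rate * 2.0 * residual  # gradient of the loss w.r.t. b
        if logging and step % valid_freq == 0:
            logs.append((step, loss))  # loss measured before this step's update
    return b, logs
```

With the defaults, the loop runs 200 iterations and records the loss every 50 steps; the learned shift approaches the difference between the target and source means.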
Attributes
fitting_loss – Fitting loss to fit the marginal constraint.
regularizer – Regularizer added to the fitting loss.
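For intuition about the regularizer: the Monge gap of [Uscidda and Cuturi, 2023] measures how far a map is from being \(c\)-optimal. On samples \(x_1, \dots, x_n\) of \(\rho\), it is the average displacement cost \(\frac{1}{n}\sum_i c(x_i, T(x_i))\) minus the optimal transport cost between \(\rho_n\) and \(T\sharp\rho_n\). The sketch below computes it with exact, unregularized OT by enumerating permutations (an optimal coupling between two uniform point clouds of equal size is a permutation), so it only suits very small \(n\). The name monge_gap_exact is hypothetical, the squared-Euclidean cost is an assumed choice of \(c\), and monge_gap_from_samples may rely on entropic OT instead, so its values will generally differ.

```python
import itertools
from typing import Callable

import numpy as np


def sq_euclidean(x: np.ndarray, y: np.ndarray) -> float:
    """Assumed cost c(x, y) = ||x - y||^2."""
    return float(np.sum((x - y) ** 2))


def monge_gap_exact(
    source: np.ndarray,
    mapped: np.ndarray,
    cost: Callable[[np.ndarray, np.ndarray], float] = sq_euclidean,
) -> float:
    """Hypothetical: mean_i c(x_i, T(x_i)) - W_c(rho_n, T#rho_n), brute-force OT."""
    n = len(source)
    cost_matrix = np.array(
        [[cost(source[i], mapped[j]) for j in range(n)] for i in range(n)]
    )
    # Cost of the coupling induced by T itself: x_i is sent to T(x_i).
    displacement = sum(cost_matrix[i, i] for i in range(n)) / n
    # Exact OT cost: minimum over all permutations (feasible only for tiny n).
    ot_cost = min(
        sum(cost_matrix[i, perm[i]] for i in range(n)) / n
        for perm in itertools.permutations(range(n))
    )
    return float(displacement - ot_cost)
```

A translation gives a gap of zero under this cost, while reversing the 1-D cloud \(\{0, 1, 2\}\) with \(x \mapsto 2 - x\) gives \(8/3\): the map moves points a total squared distance of 8, yet the two clouds are identical, so the optimal transport cost is 0.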