ott.neural.methods.monge_gap.MongeGapEstimator
- class ott.neural.methods.monge_gap.MongeGapEstimator(dim_data, model, optimizer=None, fitting_loss=None, regularizer=None, regularizer_strength=1.0, num_train_iters=10000, logging=False, valid_freq=500, rng=None)
Mapping estimator between probability measures.
It estimates a map \(T\) by minimizing the loss:
\[\min_{\theta}\; \Delta(T_\theta \sharp \mu, \nu) + \lambda R(T_\theta \sharp \rho, \rho)\]
where \(\Delta\) is a fitting loss, \(R\) is a regularizer, and \(\lambda\) is the regularizer strength. \(\Delta\) enforces the marginal constraint, i.e. it pushes \(T\) to transport \(\mu\) to \(\nu\), while \(R\) imposes an inductive bias on the learned map. The regularizer is a function that computes a score from two sets of points, here \(T_\theta \sharp \rho\) and \(\rho\).
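As a rough illustration of how the two terms combine, the sketch below evaluates the objective for a fixed map using NumPy. It is not the library's implementation: `toy_fitting_loss`, `toy_regularizer`, and `objective` are hypothetical names, the two losses are deliberately simple stand-ins for a real divergence and regularizer, and the same samples are assumed for \(\mu\) and \(\rho\). Both callables return a `(value, log)` pair, mirroring the `Tuple[float, Optional[Any]]` return type listed for the `fitting_loss` and `regularizer` parameters.

```python
import numpy as np


def toy_fitting_loss(mapped, target):
    """Stand-in for Delta: squared distance between the two sample means."""
    diff = mapped.mean(axis=0) - target.mean(axis=0)
    return float(np.sum(diff ** 2)), None  # (loss value, optional log object)


def toy_regularizer(mapped, source):
    """Stand-in for R: mean squared displacement between mapped and source points."""
    displacement = np.sum((mapped - source) ** 2, axis=1)
    return float(np.mean(displacement)), None


def objective(T, source, target, strength=1.0):
    """Evaluate Delta(T(source), target) + strength * R(T(source), source)."""
    mapped = T(source)
    fit, _ = toy_fitting_loss(mapped, target)
    reg, _ = toy_regularizer(mapped, source)
    return fit + strength * reg


rng = np.random.default_rng(0)
source = rng.normal(size=(64, 2))  # samples standing in for mu (and rho)
target = source + 3.0              # samples standing in for nu


def identity(x):
    return x


def shift(x):
    return x + 3.0


# The identity map pays only the fitting term; the shift pays only the regularizer.
print(objective(identity, source, target, strength=0.1))
print(objective(shift, source, target, strength=0.1))
```

With a small strength the shift map scores lower because it matches the target, whereas a larger strength makes the displacement penalty dominate. This is the trade-off that \(\lambda\) controls.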
For instance, \(\Delta\) can be the sinkdiv() and \(R\) the monge_gap_from_samples() [Uscidda and Cuturi, 2023] for a given cost function \(c\). In that case, it estimates a \(c\)-OT map, i.e. a map \(T\) optimal for the Monge problem induced by \(c\).
- Parameters:
  - dim_data (int) – input dimensionality of the data, required for network initialization.
  - model (BasePotential) – network architecture for map \(T\).
  - optimizer (Union[Array,ndarray,bool,number,bool,int,float,complex,Iterable[ArrayTree],Mapping[Any, ArrayTree],None]) – optimizer function for map \(T\).
  - fitting_loss (Optional[Callable[[Array,Array],Tuple[float,Optional[Any]]]]) – function that outputs a fitting loss \(\Delta\) between two families of points, as well as any log object.
  - regularizer (Optional[Callable[[Array,Array],Tuple[float,Optional[Any]]]]) – function that outputs a score from two families of points, here assumed to be of the same size, as well as any log object.
  - regularizer_strength (Union[float,Sequence[float]]) – strength of the regularizer.
  - num_train_iters (int) – total number of training iterations.
  - logging (bool) – option to return logs.
  - valid_freq (int) – frequency with which training and validation are logged.
  - rng (Optional[Array]) – random key used to seed network initialization.
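For reference, the Monge gap named in the description above is defined in [Uscidda and Cuturi, 2023], for a cost \(c\) and a reference measure \(\rho\), as
\[\mathcal{M}^c_\rho(T) = \int c(x, T(x))\,\mathrm{d}\rho(x) - W_c(\rho, T \sharp \rho),\]
where \(W_c\) denotes the optimal transport cost induced by \(c\). This quantity is nonnegative and is zero when \(T\) is a \(c\)-OT map between \(\rho\) and \(T \sharp \rho\), which is why it can serve as the regularizer \(R\). The symbols \(\mathcal{M}^c_\rho\) and \(W_c\) are introduced here for illustration and are not used elsewhere on this page. A sample-based estimator such as monge_gap_from_samples() would presumably replace the integral with an average over the samples and \(W_c\) with a computable estimate.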
Methods
setup(dim_data, neural_net, optimizer) – Set up all components required to train the network.
train_map_estimator(trainloader_source, ...) – Training loop.
Attributes
fitting_loss – Fitting loss to fit the marginal constraint.
regularizer – Regularizer added to the fitting loss.