ott.neural.methods.expectile_neural_dual.ExpectileNeuralDual
- class ott.neural.methods.expectile_neural_dual.ExpectileNeuralDual(dim_data, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, cost_fn=None, is_bidirectional=True, use_dot_product=False, expectile=0.99, expectile_loss_coef=1.0, num_train_iters=20000, valid_freq=1000, log_freq=1000, logging=False, rng=None)
Expectile-regularized Neural Optimal Transport (ENOT) [Buzun et al., 2024].
It solves the dual optimal transport problem for a specified cost function \(c(x, y)\) between two measures \(\alpha\) and \(\beta\) in \(d\)-dimensional Euclidean space, with additional regularization on the dual Kantorovich potentials. The expectile regularization enforces binding conditions on the learned dual potentials \(f\) and \(g\). The main optimization objective is
\[\sup_{g \in L_1(\beta)} \inf_{T: \, \mathbb{R}^d \to \mathbb{R}^d} \big[ \mathbb{E}_{\alpha}[c(x, T(x))] + \mathbb{E}_{\beta} [g(y)] - \mathbb{E}_{\alpha} [g(T(x))] \big],\]
where \(T(x)\) is the transport map from \(\alpha\) to \(\beta\), expressed through \(\nabla f(x)\); its explicit formula depends on the cost function and the is_bidirectional training option. The regularization term is
\[\mathbb{E} \, \mathcal{L}_{\tau} \big( c(x, T(x)) - g(T(x)) - c(x, y) + g(y) \big),\]
where \(\mathcal{L}_{\tau}\) is the least asymmetrically weighted squares loss from expectile regression.
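To make the regularization concrete, here is a minimal NumPy sketch of the asymmetrically weighted squares loss \(\mathcal{L}_{\tau}\) used in expectile regression. The function name and the example values are illustrative, not part of the ott API:

```python
import numpy as np

def expectile_loss(residual, tau=0.99):
    # Least asymmetrically weighted squares: positive residuals are
    # weighted by tau, non-positive residuals by (1 - tau).
    weight = np.where(residual > 0, tau, 1.0 - tau)
    return weight * residual ** 2

# With tau = 0.99, a positive residual is penalized 99x more than a
# negative residual of the same magnitude.
print(expectile_loss(np.array([1.0, -1.0])))  # [0.99 0.01]
```

In ENOT this loss is applied to the residual \(c(x, T(x)) - g(T(x)) - c(x, y) + g(y)\), so the asymmetry pushes the potentials toward satisfying the binding conditions from one side.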
The networks neural_f and neural_g can both provide the values of the potentials \(f\) and \(g\); alternatively, when is_bidirectional=False, neural_f provides the gradient \(\nabla f\) for the mapping \(T\).
- Parameters:
  - dim_data (int) – Input dimensionality of the data, required for network initialization.
  - neural_f (Optional[Module]) – Network architecture for potential \(f\) or its gradient \(\nabla f\).
  - neural_g (Optional[Module]) – Network architecture for potential \(g\).
  - optimizer_f (Optional[GradientTransformation]) – Optimizer for potential \(f\).
  - optimizer_g (Optional[GradientTransformation]) – Optimizer for potential \(g\).
  - cost_fn (Optional[CostFn]) – Cost function of the OT problem.
  - is_bidirectional (bool) – Whether to alternate between updating the forward and backward directions. Inspired by [Jacobs and Léger, 2020].
  - use_dot_product (bool) – Whether the duals solve the problem in correlation form.
  - expectile (float) – Parameter \(\tau\) of the expectile loss. Suggested value range: \([0.9, 1.0)\).
  - expectile_loss_coef (float) – Expectile loss weight. Suggested value range: \([0.3, 1.0]\).
  - num_train_iters (int) – Total number of training iterations.
  - valid_freq (int) – Frequency with which the model is validated.
  - log_freq (int) – Frequency with which training and validation are logged.
  - logging (bool) – Whether to return logs.
  - rng (Optional[Array]) – Random key used to seed network initializations.
Methods
- Return the Kantorovich dual potentials from the trained potentials.
- train_fn(trainloader_source, ...[, callback]) – Training and validation.