ott.neural.methods.expectile_neural_dual.ExpectileNeuralDual


class ott.neural.methods.expectile_neural_dual.ExpectileNeuralDual(dim_data, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, cost_fn=None, expectile=0.99, expectile_loss_coef=1.0, num_train_iters=20000, valid_freq=1000, log_freq=1000, logging=False, rng=None)

Expectile-regularized Neural Optimal Transport (ENOT) [Buzun et al., 2024].

It solves the dual optimal transport problem for a specified cost function \(c(x, y)\) between two measures \(\alpha\) and \(\beta\) in \(d\)-dimensional Euclidean space, with an additional regularization on the dual Kantorovich potentials. The expectile regularization enforces binding conditions on the learned dual potentials \(f\) and \(g\). The main optimization objective is

\[\sup_{g \in L_1(\beta)} \inf_{T: \, \mathbb{R}^d \to \mathbb{R}^d} \big[ \mathbb{E}_{\alpha}[c(x, T(x))] + \mathbb{E}_{\beta} [g(y)] - \mathbb{E}_{\alpha} [g(T(x))] \big],\]

where \(T(x)\) is the transport map from \(\alpha\) to \(\beta\), expressed through \(\nabla f(x)\). The explicit formula for this map depends on the twist operator of the cost function. The regularization term is

\[\mathbb{E} \mathcal{L}_{\tau} \big( c(x, T(x)) - g(T(x)) - c(x, y) + g(y) \big),\]

where \(\mathcal{L}_{\tau}\) is the least asymmetrically weighted squares loss from expectile regression, \(\mathcal{L}_{\tau}(u) = |\tau - \mathbb{1}(u < 0)|\, u^2\).
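To make the regularization term concrete, here is a minimal JAX sketch of the expectile loss and a batch estimate of the term above. It is an illustration, not the library's implementation: the helper names `expectile_loss` and `expectile_regularizer`, the index-wise pairing of source and target samples, and the toy cost, map and potential are all assumptions. With \(\tau\) close to 1, positive residuals (samples \(y\) for which \(c(x, y) - g(y)\) is smaller than at \(T(x)\)) are weighted far more heavily than negative ones.

```python
import jax
import jax.numpy as jnp


def expectile_loss(u, tau):
    """Least asymmetrically weighted squares: |tau - 1[u < 0]| * u**2."""
    weight = jnp.where(u < 0, 1.0 - tau, tau)
    return weight * u**2


def expectile_regularizer(xs, ys, transport, g, cost, tau):
    """Batch estimate of E L_tau(c(x, T(x)) - g(T(x)) - c(x, y) + g(y)).

    Assumption (hypothetical helper): rows of `xs` (samples of alpha) and
    `ys` (samples of beta) are paired index-wise.
    """
    txs = jax.vmap(transport)(xs)
    residual = (
        jax.vmap(cost)(xs, txs) - jax.vmap(g)(txs)
        - jax.vmap(cost)(xs, ys) + jax.vmap(g)(ys)
    )
    return jnp.mean(expectile_loss(residual, tau))


# Illustrative choices only: squared Euclidean cost, identity map, zero potential.
def sq_cost(x, y):
    return 0.5 * jnp.sum((x - y) ** 2)


def identity(x):
    return x


def zero_g(y):
    return jnp.zeros(())


xs = jnp.zeros((4, 2))
ys = jnp.ones((4, 2))
print(expectile_loss(jnp.array([2.0, -2.0]), 0.9))  # weights 0.9 and 0.1
print(expectile_regularizer(xs, ys, identity, zero_g, sq_cost, tau=0.9))
```

In this toy setup every residual equals \(-c(x, y) = -1\), so each sample contributes \((1 - \tau) \cdot 1 = 0.1\) to the average.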

Parameters:
  • dim_data (int) – Dimensionality of the input data, required to initialize the networks.

  • neural_f (Optional[Module]) – Network architecture for potential \(f\) or its gradient \(\nabla f\).

  • neural_g (Optional[Module]) – Network architecture for potential \(g\).

  • optimizer_f (Optional[GradientTransformation]) – Optimizer for potential \(f\).

  • optimizer_g (Optional[GradientTransformation]) – Optimizer for potential \(g\).

  • cost_fn (Optional[TICost]) – Translation-invariant cost function, i.e. \(c(x, y) = h(x - y)\).

  • expectile (float) – Parameter \(\tau\) of the expectile loss. Suggested range: \([0.9, 1.0)\).

  • expectile_loss_coef (float) – Weight of the expectile loss. Suggested range: \([0.3, 1.0]\).

  • num_train_iters (int) – Total number of training iterations.

  • valid_freq (int) – Frequency (in iterations) at which the model is validated.

  • log_freq (int) – Frequency (in iterations) at which training and validation metrics are logged.

  • logging (bool) – Whether to return training logs.

  • rng (Optional[Array]) – Random key used to seed the network initializations.
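For a translation-invariant cost \(c(x, y) = h(x - y)\) with strictly convex \(h\), first-order optimality of the potential gives \(\nabla f(x) = \nabla h(x - T(x))\), hence \(T(x) = x - \nabla h^{*}(\nabla f(x))\), where \(h^{*}\) is the convex conjugate of \(h\). The sketch below applies this twist formula in plain JAX and estimates the bracketed objective from samples. It does not use the library's `TICost` API; `twist_transport`, `dual_objective`, the choice \(h(z) = \tfrac{1}{2}\|z\|^2\) (for which \(\nabla h^{*}\) is the identity) and the quadratic/linear potentials are hypothetical, chosen only so the result is easy to check.

```python
import jax
import jax.numpy as jnp


def twist_transport(f, grad_h_star):
    """T(x) = x - grad h*(grad f(x)) for a cost c(x, y) = h(x - y)."""
    grad_f = jax.grad(f)

    def transport(x):
        return x - grad_h_star(grad_f(x))

    return transport


# Assumption: h(z) = 0.5 * ||z||^2, which is its own convex conjugate,
# so grad h* is the identity map.
def cost(x, y):
    return 0.5 * jnp.sum((x - y) ** 2)


def grad_h_star(w):
    return w


# Hypothetical potentials, for illustration only.
def f(x):
    return 0.25 * jnp.sum(x**2)  # grad f(x) = 0.5 * x, hence T(x) = 0.5 * x


def g(y):
    return jnp.sum(y)


T = twist_transport(f, grad_h_star)


def dual_objective(xs, ys):
    """Monte Carlo estimate of E_alpha[c(x, T(x))] + E_beta[g(y)] - E_alpha[g(T(x))]."""
    txs = jax.vmap(T)(xs)
    return (
        jnp.mean(jax.vmap(cost)(xs, txs))
        + jnp.mean(jax.vmap(g)(ys))
        - jnp.mean(jax.vmap(g)(txs))
    )


xs = jax.random.normal(jax.random.PRNGKey(0), (256, 2))
ys = jax.random.normal(jax.random.PRNGKey(1), (256, 2)) + 1.0
print(T(jnp.array([2.0, -4.0])))  # T(x) = 0.5 * x for this choice of f
print(dual_objective(xs, ys))
```

Here \(c(x, T(x)) = \tfrac{1}{8}\|x\|^2\) and \(g(T(x)) = \tfrac{1}{2}\sum_i x_i\), so the estimate can be checked by hand.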

Methods

to_dual_potentials()

Return the Kantorovich dual potentials from the trained potentials.

train_fn(trainloader_source, ...[, callback])

Run training and validation.