ott.neural.methods.expectile_neural_dual.ExpectileNeuralDual

class ott.neural.methods.expectile_neural_dual.ExpectileNeuralDual(dim_data, neural_f=None, neural_g=None, optimizer_f=None, optimizer_g=None, cost_fn=None, is_bidirectional=True, use_dot_product=False, expectile=0.99, expectile_loss_coef=1.0, num_train_iters=20000, valid_freq=1000, log_freq=1000, logging=False, rng=None)

Expectile-regularized Neural Optimal Transport (ENOT) [Buzun et al., 2024].

It solves the dual optimal transport problem for a specified cost function \(c(x, y)\) between two measures \(\alpha\) and \(\beta\) in \(d\)-dimensional Euclidean space, with an additional regularization on the dual Kantorovich potentials. The expectile regularization enforces binding conditions on the learned dual potentials \(f\) and \(g\). The main optimization objective is

\[\sup_{g \in L_1(\beta)} \inf_{T: \, \mathbb{R}^d \to \mathbb{R}^d} \big[ \mathbb{E}_{\alpha}[c(x, T(x))] + \mathbb{E}_{\beta} [g(y)] - \mathbb{E}_{\alpha} [g(T(x))] \big],\]

where \(T(x)\) is the transport mapping from \(\alpha\) to \(\beta\), expressed through \(\nabla f(x)\). The explicit formula depends on the cost function and the is_bidirectional training option. The regularization term is

\[\mathbb{E} \mathcal{L}_{\tau} \big( c(x, T(x)) - g(T(x)) - c(x, y) + g(y) \big),\]

where \(\mathcal{L}_{\tau}\) is the asymmetrically weighted least-squares loss from expectile regression.
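For concreteness, here is a minimal JAX sketch of \(\mathcal{L}_{\tau}\) applied elementwise to a residual; the function name expectile_loss is illustrative and not part of the library API:

    import jax.numpy as jnp

    def expectile_loss(diff, expectile=0.99):
      # Asymmetrically weighted squared loss: L_tau(u) = |tau - 1{u < 0}| * u^2.
      # Positive residuals are weighted by `expectile`, negative ones by
      # `1 - expectile`, so tau > 0.5 penalizes positive residuals more heavily.
      weight = jnp.where(diff >= 0, expectile, 1.0 - expectile)
      return weight * diff ** 2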

The neural_f and neural_g networks can be configured in two ways (see the sketch after this list):

  1. both provide the values of the potentials \(f\) and \(g\), or

  2. when is_bidirectional=False, neural_f provides the gradient \(\nabla f\) that parameterizes the mapping \(T\).
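A minimal configuration sketch of both cases, assuming the PotentialMLP architecture and its is_potential flag from ott.neural.networks.potentials (both are assumptions to verify against the installed OTT version):

    from ott.neural.methods.expectile_neural_dual import ExpectileNeuralDual
    from ott.neural.networks import potentials

    # Case 1: both networks model potential values (is_bidirectional=True).
    neural_f = potentials.PotentialMLP(dim_hidden=[64, 64, 64])
    neural_g = potentials.PotentialMLP(dim_hidden=[64, 64, 64])
    enot = ExpectileNeuralDual(
        dim_data=2, neural_f=neural_f, neural_g=neural_g, is_bidirectional=True
    )

    # Case 2: neural_f models the gradient map (a vector field) directly.
    neural_grad_f = potentials.PotentialMLP(dim_hidden=[64, 64, 64], is_potential=False)
    enot_grad = ExpectileNeuralDual(
        dim_data=2, neural_f=neural_grad_f, neural_g=neural_g, is_bidirectional=False
    )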

Parameters:
  • dim_data (int) – Input dimensionality of data required for network init.

  • neural_f (Optional[Module]) – Network architecture for potential \(f\) or its gradient \(\nabla f\).

  • neural_g (Optional[Module]) – Network architecture for potential \(g\).

  • optimizer_f (Optional[GradientTransformation]) – Optimizer function for potential \(f\).

  • optimizer_g (Optional[GradientTransformation]) – Optimizer function for potential \(g\).

  • cost_fn (Optional[CostFn]) – Cost function of the OT problem.

  • is_bidirectional (bool) – Whether to alternate between updating the forward and backward directions. Inspired by [Jacobs and Léger, 2020].

  • use_dot_product (bool) – Whether the duals solve the problem in correlation form.

  • expectile (float) – Parameter \(\tau\) of the expectile loss. The suggested range of values is \([0.9, 1.0)\).

  • expectile_loss_coef (float) – Weight of the expectile loss term. The suggested range of values is \([0.3, 1.0]\).

  • num_train_iters (int) – Number of total training iterations.

  • valid_freq (int) – Frequency with which the model is validated.

  • log_freq (int) – Frequency with which training and validation metrics are logged.

  • logging (bool) – Option to return logs.

  • rng (Optional[Array]) – Random key used to seed network initializations.

Methods

to_dual_potentials()

Return the Kantorovich dual potentials from the trained networks.

train_fn(trainloader_source, ...[, callback])

Training and validation.
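A hedged end-to-end sketch of training and applying the solver. The loader variables are illustrative generators, the full train_fn argument list (abbreviated above) is an assumption to check against the source, and transport is a method of the DualPotentials object returned by to_dual_potentials():

    import jax
    import jax.numpy as jnp

    from ott.neural.methods.expectile_neural_dual import ExpectileNeuralDual

    def sampler(rng, mean):
      # Illustrative infinite loader: yields batches of 2-D Gaussian samples.
      while True:
        rng, key = jax.random.split(rng)
        yield jax.random.normal(key, (1024, 2)) + mean

    train_src = sampler(jax.random.PRNGKey(0), 0.0)
    train_tgt = sampler(jax.random.PRNGKey(1), 5.0)
    valid_src = sampler(jax.random.PRNGKey(2), 0.0)
    valid_tgt = sampler(jax.random.PRNGKey(3), 5.0)

    enot = ExpectileNeuralDual(dim_data=2, expectile=0.98)
    # Assumed argument order: source/target train loaders, then valid loaders.
    enot.train_fn(train_src, train_tgt, valid_src, valid_tgt)

    dual_potentials = enot.to_dual_potentials()
    y_hat = dual_potentials.transport(jnp.zeros((16, 2)))  # map source samples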