ott.neural.networks.icnn.ICNN


class ott.neural.networks.icnn.ICNN(dim_hidden, *, input_dim, output_dim=1, rectifier_fn=softplus, act_fn=relu, wx_inject=True, use_bias=True, use_softmax=False, use_sinkhorn=False, pos_def_rank=0, principled_init=False, kernel_init=<function variance_scaling.<locals>.init>, wz_kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, rngs)

Input convex neural network (ICNN).

Implementation of input convex neural networks as introduced in [Amos et al., 2017] with flexible input re-injection, multiple rectifier options, and optional positive-definite quadratic potentials [Vesseron and Cuturi, 2024].

The network computes a convex function \(f \colon \mathbb{R}^d \to \mathbb{R}^k\) where convexity holds component-wise when \(k > 1\).

Architecture:

z_0 = act(W_x0 @ x + b_0)
z_i = act(W_z_i @ z_{i-1} + W_x_i @ x + b_i)  # wx_inject controls W_x_i
z_N = W_z_N @ z_{N-1} + W_x_N @ x + b_N  # final layer is linear (no act)
out = z_N + pos_def_potentials(x)  # optional

Biases b_i are included only when use_bias=True. They belong to the W_x0 layer and the W_z layers; the W_x input-injection terms are always bias-free. Convexity in x is enforced by requiring W_z_i >= 0 (via the rectifier) and by using activation functions that are convex and non-decreasing. The final layer applies no activation; convexity is preserved because a non-negatively weighted sum of convex features, plus an affine term in x, is convex.
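The recurrence above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the library's implementation: the helpers `init_params` and `icnn_forward` are hypothetical names, the raw W_z kernels are made non-negative with softplus, input injection is assumed at every layer, and the optional positive-definite potentials are omitted.

```python
import jax
import jax.numpy as jnp


def init_params(key, input_dim, dim_hidden, output_dim=1):
    """Hypothetical helper: raw (unconstrained) parameters for each layer."""
    dims = list(dim_hidden) + [output_dim]
    params = []
    prev = None
    for d in dims:
        key, kx, kz = jax.random.split(key, 3)
        layer = {
            "W_x": jax.random.normal(kx, (d, input_dim)) / jnp.sqrt(input_dim),
            "b": jnp.zeros((d,)),
        }
        if prev is not None:
            # Raw W_z kernel; it is rectified to be non-negative at call time.
            layer["W_z_raw"] = jax.random.normal(kz, (d, prev)) - 3.0
        params.append(layer)
        prev = d
    return params


def icnn_forward(params, x, act=jax.nn.relu, rectifier=jax.nn.softplus):
    """Evaluate the sketch on a single input x of shape (input_dim,)."""
    # z_0 = act(W_x0 @ x + b_0)
    z = act(params[0]["W_x"] @ x + params[0]["b"])
    for i, layer in enumerate(params[1:], start=1):
        w_z = rectifier(layer["W_z_raw"])  # enforce W_z_i >= 0
        pre = w_z @ z + layer["W_x"] @ x + layer["b"]
        is_last = i == len(params) - 1
        z = pre if is_last else act(pre)  # final layer is linear (no act)
    return z


params = init_params(jax.random.PRNGKey(0), input_dim=3, dim_hidden=[16, 16])
kx, ky = jax.random.split(jax.random.PRNGKey(1))
x = jax.random.normal(kx, (3,))
y = jax.random.normal(ky, (3,))

# Midpoint convexity: f((x + y) / 2) <= (f(x) + f(y)) / 2.
f_mid = icnn_forward(params, 0.5 * (x + y))
f_avg = 0.5 * (icnn_forward(params, x) + icnn_forward(params, y))
```

Comparing `f_mid` with `f_avg` on random pairs is only a spot check of convexity, not a proof; the guarantee comes from the non-negative W_z kernels and the convex, non-decreasing activation.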

Parameters:
  • dim_hidden (Sequence[int]) – Sequence of hidden layer sizes. The output dimension defaults to 1 (scalar potential); set output_dim for vector output.

  • input_dim (int) – Dimension of the input x.

  • output_dim (int) – Output dimension. Defaults to 1 (scalar convex function). When > 1, each output component is convex in the input.

  • rectifier_fn (Callable[[Array], Array]) – Function applied to W_z kernels to enforce non-negativity. The default is softplus().

  • act_fn (Callable[[Array], Array]) – Activation function. It must be convex and non-decreasing for the network to be convex. The default is relu().

  • wx_inject (Union[bool, Tuple[bool, ...], int]) – Controls input re-injection at intermediate layers.

  • use_bias (bool) – Whether to use bias terms.

  • use_softmax (bool) – If True, the W_z PositiveDense layers use column-wise softmax normalization instead of rectifier_fn.

  • use_sinkhorn (bool) – If True, the W_z PositiveDense layers use Sinkhorn normalization instead of rectifier_fn.

  • pos_def_rank (int) – Rank of the optional PosDefPotentials term. Set to 0 (the default) to disable it.

  • principled_init (bool) – If True, override wz_kernel_init and the W_z bias initializer with the principled ICNN initialization of [Hoedt and Klambauer, 2023], which controls correlation and variance propagation through layers with positive weights.

  • kernel_init (Union[Initializer, Callable[..., Any]]) – Initializer for W_x (unrestricted) weights.

  • wz_kernel_init (Union[Initializer, Callable[..., Any]]) – Initializer for W_z (positive) weights. Ignored when principled_init=True.

  • bias_init (Union[Initializer, Callable[..., Any]]) – Initializer for biases.

  • rngs (Rngs) – Random number generators.

  • args (Any)

  • kwargs (Any)
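The rectifier_fn and use_softmax options both map an unconstrained kernel to a non-negative one. The sketch below illustrates two such maps in plain JAX. It assumes a kernel of shape (in_features, out_features) and reads "column-wise" as normalizing over the input axis; both are assumptions about layout, and the library's PositiveDense layer may differ. Sinkhorn normalization is not sketched here.

```python
import jax
import jax.numpy as jnp

# Unconstrained kernel with an assumed (in_features, out_features) layout.
raw = jax.random.normal(jax.random.PRNGKey(0), (4, 3))

# Option 1: elementwise rectifier (softplus is the documented default).
w_rect = jax.nn.softplus(raw)

# Option 2: softmax over the input axis (assumed meaning of "column-wise"),
# so each column holds non-negative weights that sum to one.
w_soft = jax.nn.softmax(raw, axis=0)
col_sums = w_soft.sum(axis=0)
```

Under this reading, softmax normalization makes each output unit a convex combination of its inputs, which bounds the scale of the layer, whereas an elementwise rectifier leaves the scale free.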

Return type:

Any

Methods

eval(**attributes)

Sets the Module to evaluation mode.

gradient(x)

Gradient of the convex potential w.r.t. the input x.

iter_children()

Warning: this method is deprecated; use iter_children() instead.

iter_modules()

Warning: this method is deprecated; use iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.

Attributes

is_potential

Whether this module represents a potential (True) or a vector field (False).
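When a module represents a potential, the quantity of interest is often the gradient of that potential with respect to its input, as exposed by gradient(x). The sketch below illustrates the idea with jax.grad on a hypothetical stand-in convex function (`potential`), not on the ICNN class itself, and checks that the resulting gradient map is monotone, a property of every differentiable convex function.

```python
import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (5, 3))


def potential(x):
    """Stand-in convex potential: log-sum-exp of linear maps plus a quadratic."""
    return jax.nn.logsumexp(A @ x) + 0.5 * jnp.sum(x ** 2)


grad_fn = jax.grad(potential)      # gradient at one point of shape (3,)
batched_grad = jax.vmap(grad_fn)   # gradient for a batch of shape (n, 3)

xs = jax.random.normal(jax.random.PRNGKey(1), (8, 3))
ys = jax.random.normal(jax.random.PRNGKey(2), (8, 3))
gx, gy = batched_grad(xs), batched_grad(ys)

# Gradients of convex functions are monotone:
# <grad f(x) - grad f(y), x - y> >= 0 for every pair (x, y).
monotone_gap = jnp.sum((gx - gy) * (xs - ys), axis=-1)
```

The gradient has the same shape as the input, so for a scalar potential on R^d it is itself a map from R^d to R^d.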