ott.neural.networks.icnn.ICNN
class ott.neural.networks.icnn.ICNN(dim_hidden, *, input_dim, output_dim=1, rectifier_fn=softplus, act_fn=relu, wx_inject=True, use_bias=True, use_softmax=False, use_sinkhorn=False, pos_def_rank=0, principled_init=False, kernel_init=<variance_scaling initializer>, wz_kernel_init=<variance_scaling initializer>, bias_init=zeros, rngs)
Input convex neural network (ICNN).
Implementation of input convex neural networks as introduced in [Amos et al., 2017] with flexible input re-injection, multiple rectifier options, and optional positive-definite quadratic potentials [Vesseron and Cuturi, 2024].
The network computes a convex function \(f \colon \mathbb{R}^d \to \mathbb{R}^k\) where convexity holds component-wise when \(k > 1\).
Architecture:
    z_0 = act(W_x0 @ x + b_0)
    z_i = act(W_z_i @ z_{i-1} + W_x_i @ x + b_i)    # wx_inject controls W_x_i
    z_N = W_z_N @ z_{N-1} + W_x_N @ x + b_N         # final layer is linear (no act)
    out = z_N + pos_def_potentials(x)               # optional
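The recursion above can be sketched in plain JAX to make the data flow concrete. This is a minimal sketch, not OTT's implementation: `init_params` and `icnn_forward` are hypothetical helpers, the parameter layout (a dict of per-layer lists) is an assumption, input re-injection happens at every layer (as with `wx_inject=True`), each layer has one bias `b_i`, and the optional positive-definite term is omitted. Kernels are stored with shape `(in, out)` and applied as `x @ W`, which is the transpose of the `W @ x` notation used above.

```python
import jax
import jax.numpy as jnp


def init_params(key, input_dim, dim_hidden, output_dim=1):
    """Hypothetical parameter layout for the ICNN recursion (not OTT's API)."""
    sizes = list(dim_hidden) + [output_dim]
    keys = jax.random.split(key, 2 * len(sizes))
    params = {"W_x": [], "W_z": [], "b": []}
    prev = None
    for i, size in enumerate(sizes):
        # W_x_i: unrestricted input-injection weights, shape (input_dim, size).
        params["W_x"].append(0.1 * jax.random.normal(keys[2 * i], (input_dim, size)))
        # W_z_i: raw weights; the rectifier makes them non-negative when used.
        # Layer 0 has no previous hidden state, hence no W_z.
        params["W_z"].append(
            None if prev is None
            else 0.1 * jax.random.normal(keys[2 * i + 1], (prev, size))
        )
        params["b"].append(jnp.zeros(size))
        prev = size
    return params


def icnn_forward(params, x, act=jax.nn.relu, rectifier=jax.nn.softplus):
    """Apply z_0 = act(x W_x0 + b_0), then z_i = act(z W_z_i^+ + x W_x_i + b_i)."""
    n = len(params["W_x"])
    z = act(x @ params["W_x"][0] + params["b"][0])
    for i in range(1, n):
        pre = z @ rectifier(params["W_z"][i]) + x @ params["W_x"][i] + params["b"][i]
        # The final layer is affine in (z, x): no activation is applied.
        z = pre if i == n - 1 else act(pre)
    return z
```

For example, `icnn_forward(init_params(jax.random.PRNGKey(0), input_dim=3, dim_hidden=[8, 8]), x)` maps a batch `x` of shape `(batch, 3)` to shape `(batch, 1)`, matching the default scalar output.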
Biases b_i (on W_x0 and the W_z layers) are included only when use_bias=True; the W_x input-injection terms are always bias-free. Convexity is enforced by requiring W_z_i >= 0 (via the rectifier) and using activation functions that are convex and non-decreasing. The final layer applies no activation; convexity is preserved because a non-negatively weighted sum of convex features, plus the affine term W_x_N @ x + b_N, is convex.
- Parameters:
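Both conditions in the convexity argument matter, and a toy 1-D check shows what goes wrong without them. The sketch below uses a midpoint (Jensen) test, f((x+y)/2) <= (f(x)+f(y))/2 on sample pairs; `midpoint_convex` and the toy functions are hypothetical illustrations, not part of OTT. Passing this test on samples is necessary for convexity, not a proof of it.

```python
import jax
import jax.numpy as jnp


def midpoint_convex(f, xs, ys, tol=1e-6):
    """Check f((x+y)/2) <= (f(x)+f(y))/2 on paired samples (necessary for convexity)."""
    return bool(jnp.all(f(0.5 * (xs + ys)) <= 0.5 * (f(xs) + f(ys)) + tol))


raw_w = -1.0  # an unconstrained W_z entry that happens to be negative


def f_raw(x):
    # Negative weight on a convex feature: -relu(x) is concave.
    return raw_w * jax.nn.relu(x)


def f_rect(x):
    # Rectified weight: softplus(-1.0) is about 0.313 > 0, so convexity is kept.
    return jax.nn.softplus(raw_w) * jax.nn.relu(x)


def f_dec(x):
    # exp(-t) is convex but decreasing; composing it with relu breaks convexity.
    return jnp.exp(-jax.nn.relu(x))


xs = jnp.linspace(-2.0, 2.0, 50)
ys = -xs  # symmetric pairs, so every midpoint is 0
```

Here `midpoint_convex(f_rect, xs, ys)` holds, while `f_raw` (negative weight) and `f_dec` (decreasing outer activation) both fail, which is why the W_z kernels are rectified and the activation must be non-decreasing as well as convex.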
    dim_hidden (Sequence[int]) – Sequence of hidden layer sizes. The output dimension defaults to 1 (scalar potential); set output_dim for vector output.
    input_dim (int) – Dimension of the input x.
    output_dim (int) – Output dimension. Defaults to 1 (scalar convex function). When > 1, each output component is convex in the input.
    rectifier_fn (Callable[[Array], Array]) – Function applied to the W_z kernels to enforce non-negativity. The default is softplus().
    act_fn (Callable[[Array], Array]) – Activation function; it must be convex and non-decreasing for the network to be convex. The default is relu().
    wx_inject (Union[bool, Tuple[bool, ...], int]) – Controls input re-injection at intermediate layers.
    use_bias (bool) – Whether to use bias terms.
    use_softmax (bool) – If True, the W_z PositiveDense layers use column-wise softmax normalization instead of rectifier_fn.
    use_sinkhorn (bool) – If True, the W_z PositiveDense layers use Sinkhorn normalization instead of rectifier_fn.
    pos_def_rank (int) – Rank of the optional PosDefPotentials term. Set to 0 (the default) to disable it.
    principled_init (bool) – If True, override wz_kernel_init and the W_z bias initializer with the principled ICNN initialization of [Hoedt and Klambauer, 2023], which controls correlation and variance propagation through layers with positive weights.
    kernel_init (Union[Initializer, Callable[..., Any]]) – Initializer for the W_x (unrestricted) weights.
    wz_kernel_init (Union[Initializer, Callable[..., Any]]) – Initializer for the W_z (positive) weights. Ignored when principled_init=True.
    bias_init (Union[Initializer, Callable[..., Any]]) – Initializer for biases.
    rngs (Rngs) – Random number generators.
    args (Any)
    kwargs (Any)
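The use_softmax and use_sinkhorn options replace rectifier_fn with a normalization that also yields non-negative W_z weights. The exact normalization OTT applies (axis convention, iteration count, marginals for non-square kernels) is not stated here, so the following is only a sketch under assumptions: kernels have shape `(in, out)`, so "column-wise" means `axis=0`; the Sinkhorn variant is plain Sinkhorn-Knopp scaling of `exp(w)` toward a doubly stochastic matrix, which is only well defined for square kernels. `softmax_columns` and `sinkhorn_normalize` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def softmax_columns(w):
    """Column-wise softmax: every entry is > 0 and each column sums to 1."""
    return jax.nn.softmax(w, axis=0)


def sinkhorn_normalize(w, num_iters=200):
    """Sinkhorn-Knopp scaling of exp(w) toward a doubly stochastic matrix (square w)."""

    def body(_, k):
        k = k / k.sum(axis=1, keepdims=True)  # rescale rows to sum to 1
        return k / k.sum(axis=0, keepdims=True)  # rescale columns to sum to 1

    return jax.lax.fori_loop(0, num_iters, body, jnp.exp(w))
```

Both maps produce strictly positive weights, so either can stand in for rectifier_fn in the convexity argument; the normalizations additionally bound the scale of each layer's weights, which a pointwise rectifier does not.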
- Return type:
Any
Methods
    eval(**attributes) – Sets the Module to evaluation mode.
    gradient(x) – Gradient of the convex potential w.r.t. the input x.
    Warning: this method is deprecated; use iter_children() instead.
    Warning: this method is deprecated; use iter_modules() instead.
    perturb(name, value[, variable_type]) – Extract gradients of intermediate values during training.
    sow(variable_type, name, value[, reduce_fn, ...]) – Store intermediate values during module execution for later extraction.
    train(**attributes) – Sets the Module to training mode.
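The summary of gradient(x) only states that it returns the gradient of the convex potential. How OTT computes it is not described here; one plausible way to obtain such a map in JAX is `jax.grad` for a single input composed with `jax.vmap` over a batch. In this sketch `potential` is a hypothetical stand-in for a trained scalar ICNN, chosen so the result is easy to verify by hand.

```python
import jax
import jax.numpy as jnp


def potential(x):
    """Hypothetical convex stand-in for a scalar ICNN: 0.5 * ||x||^2 + sum(x)."""
    return 0.5 * jnp.sum(x ** 2) + jnp.sum(x)


# x -> grad f(x) for one input, vectorized over the leading batch axis.
batched_gradient = jax.vmap(jax.grad(potential))

x = jnp.array([[0.0, 1.0], [2.0, -3.0]])
g = batched_gradient(x)  # equals x + 1 for this quadratic potential
```

Because the potential is convex, its gradient map is monotone: <grad f(x) - grad f(y), x - y> >= 0 for all x, y, which is the property that makes gradients of convex potentials useful as transport maps.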
Attributes
Whether this module represents a potential (True) or a vector field (False).