ott.neural.networks.icnn.KeyNet


class ott.neural.networks.icnn.KeyNet(dim_hidden, *, input_dim, output_dim=None, num_outputs=None, resnet=False, act_fn=<jax._src.custom_derivatives.custom_jvp object>, wx_inject=True, use_bias=True, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, final_layer_scale=None, rngs)

Vector-output network with ICNN-like architecture.

Introduced in [Olausson et al., 2026] to amortize maximum inner product search with learned support functions. Unlike ICNN, which outputs a scalar convex function and requires autodiff to compute gradients, KeyNet directly outputs vectors. The architecture mirrors ICNN but drops the non-negativity constraint on the layer-to-layer weights \(W_z\).
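To make the amortization target concrete, the sketch below computes by brute force the quantities that a learned support function stands in for: the support function \(h(x) = \max_k \langle k, x \rangle\) of a finite key set, the maximizing key, and the gradient of \(h\), which equals that key wherever the maximizer is unique. This is an illustration in plain JAX based on our reading of the description above, not the paper's training setup; the names `support_function` and `best_key` and the toy `keys` array are hypothetical and not part of the library.

```python
import jax
import jax.numpy as jnp


def support_function(x, keys):
  """h(x) = max_k <k, x> over the rows of `keys` (hypothetical helper)."""
  return jnp.max(keys @ x)


def best_key(x, keys):
  """Brute-force maximum inner product search: the row of `keys` maximizing <k, x>."""
  return keys[jnp.argmax(keys @ x)]


# Toy database of three keys in 2D and a single query.
keys = jnp.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])
x = jnp.array([0.3, 2.0])

k_star = best_key(x, keys)
h_value = support_function(x, keys)
# Where the maximizer is unique, the gradient of h is the maximizing key.
grad_h = jax.grad(support_function)(x, keys)
```

A network that outputs `k_star` directly from `x` avoids both the scan over `keys` and the autodiff call, which is the contrast with a scalar-output ICNN drawn above.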

Architecture:

z_0 = act(W_x0 @ x + b_0)
z_i = act(W_z_i @ z_{i-1} + W_x_i @ x + b_i)  # wx_inject controls W_x_i
z_N = W_z_N @ z_{N-1} + W_x_N @ x + b_N  # final layer is linear (no act)
out = x + z_N if resnet else z_N

Biases \(b_i\) (on W_x0 and the W_z layers) are included only when use_bias=True; the W_x input-injection terms are always bias-free. The final layer applies no activation, so the output vector is unconstrained (e.g. signed gradients). When resnet=True the output is \(x + F(x)\), i.e. the model learns a correction to the input query. The scalar potential is recovered as \(f(x) = \langle \mathrm{KeyNet}(x), x \rangle\).
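The recurrence above can be sketched in a few lines of plain JAX. This is a minimal illustration and not the library's implementation: it assumes input injection at every layer (`wx_inject=True`), biases enabled, ReLU as the activation, and at least one hidden layer. The parameter layout and the helpers `init_params` and `keynet_forward` are hypothetical.

```python
import jax
import jax.numpy as jnp


def init_params(key, input_dim, dim_hidden, output_dim):
  """Hypothetical parameter layout: one dict per layer, final layer last."""
  sizes = list(dim_hidden) + [output_dim]
  params, prev = [], None
  for size in sizes:
    key, k_x, k_z = jax.random.split(key, 3)
    layer = {
        "W_x": jax.random.normal(k_x, (size, input_dim)) / jnp.sqrt(input_dim),
        "b": jnp.zeros(size),
    }
    if prev is not None:
      # Layer-to-layer weights are unconstrained (no non-negativity).
      layer["W_z"] = jax.random.normal(k_z, (size, prev)) / jnp.sqrt(prev)
    params.append(layer)
    prev = size
  return params


def keynet_forward(params, x, resnet=False, act=jax.nn.relu):
  first, *hidden, last = params
  # z_0 = act(W_x0 @ x + b_0)
  z = act(first["W_x"] @ x + first["b"])
  # z_i = act(W_z_i @ z_{i-1} + W_x_i @ x + b_i)
  for layer in hidden:
    z = act(layer["W_z"] @ z + layer["W_x"] @ x + layer["b"])
  # z_N = W_z_N @ z_{N-1} + W_x_N @ x + b_N (linear, no activation)
  z = last["W_z"] @ z + last["W_x"] @ x + last["b"]
  # The residual form requires output_dim == input_dim.
  return x + z if resnet else z


params = init_params(jax.random.PRNGKey(0), input_dim=3, dim_hidden=[8, 8], output_dim=3)
x = jnp.array([1.0, -2.0, 0.5])
out = keynet_forward(params, x)
out_res = keynet_forward(params, x, resnet=True)
```

With `resnet=True` the same parameters produce `x + out`, so the network only has to learn the difference between the query and its key.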

Parameters:
  • dim_hidden (Sequence[int]) – Sequence of hidden layer sizes.

  • input_dim (int) – Dimension of the input x.

  • output_dim (Optional[int]) – Output vector dimension. Defaults to input_dim, which is the usual choice when the output is interpreted as the gradient of a potential.

  • num_outputs (Optional[int]) – Number of output vectors.

  • resnet (bool) – If True, output \(x + F(x)\) instead of \(F(x)\).

  • act_fn (Callable[[Array], Array]) – Activation function.

  • wx_inject (Union[bool, Tuple[bool, ...], int]) – Controls input re-injection pattern.

  • use_bias (bool) – Whether to use bias terms.

  • kernel_init (Union[Initializer, Callable[..., Any]]) – Initializer for all weights.

  • bias_init (Union[Initializer, Callable[..., Any]]) – Initializer for biases.

  • final_layer_scale (Optional[float]) – Scale for final layer init. Defaults to 0.01 for resnet mode (small initial corrections), 1.0 otherwise.

  • rngs (Rngs) – Random number generators.

  • args (Any)

  • kwargs (Any)
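The default for final_layer_scale described above can be written out as a small helper, together with a toy check of why a small scale keeps a residual network close to the identity at initialization. This is a sketch under stated assumptions: the helper `resolve_final_layer_scale` is hypothetical, and multiplying the final layer's initial weights by the scale is our assumption about how the value is applied, which this page does not spell out.

```python
import jax
import jax.numpy as jnp


def resolve_final_layer_scale(final_layer_scale, resnet):
  """Documented default: 0.01 in resnet mode, 1.0 otherwise (hypothetical helper)."""
  if final_layer_scale is not None:
    return final_layer_scale
  return 0.01 if resnet else 1.0


k_w, k_z = jax.random.split(jax.random.PRNGKey(0))
z_prev = jax.random.normal(k_z, (16,))                   # stand-in for z_{N-1}
W_unit = jax.random.normal(k_w, (4, 16)) / jnp.sqrt(16.0)  # unscaled final weights
x = jnp.ones(4)

scale = resolve_final_layer_scale(None, resnet=True)
# Assumption: the scale multiplies the initial final-layer weights.
correction = (scale * W_unit) @ z_prev
out = x + correction  # close to x at initialization
```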

Return type:

Any

Methods

eval(**attributes)

Sets the Module to evaluation mode.

gradient(x)

Compute the vector output (predicted gradient / key).

iter_children()

Warning: this method is deprecated; use iter_children() instead.

iter_modules()

Warning: this method is deprecated; use iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.

Attributes

is_potential

KeyNet models a potential via \(f(x) = \langle g(x), x \rangle\).
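For a generic vector field \(g\), differentiating \(f(x) = \langle g(x), x \rangle\) gives \(\nabla f(x) = g(x) + J_g(x)^\top x\) by the product rule, so \(g(x)\) coincides with \(\nabla f(x)\) exactly when \(J_g(x)^\top x = 0\). That condition holds, for instance, for gradients of support functions, which are positively homogeneous of degree one. The check below verifies the product-rule identity numerically on a toy map; the function `g` and the matrix `A` are hypothetical stand-ins and not a KeyNet.

```python
import jax
import jax.numpy as jnp

A = jnp.array([[0.5, -1.0], [2.0, 0.3]])


def g(x):
  """Toy vector field standing in for the network output."""
  return jnp.tanh(A @ x)


def f(x):
  """Scalar potential f(x) = <g(x), x>."""
  return jnp.vdot(g(x), x)


x = jnp.array([0.7, -0.2])
grad_f = jax.grad(f)(x)
# Product rule: grad f(x) = g(x) + J_g(x)^T x, with J[i, j] = d g_i / d x_j.
product_rule = g(x) + jax.jacobian(g)(x).T @ x
```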