ott.neural.networks.icnn.KeyNet
- class ott.neural.networks.icnn.KeyNet(dim_hidden, *, input_dim, output_dim=None, num_outputs=None, resnet=False, act_fn=<jax._src.custom_derivatives.custom_jvp object>, wx_inject=True, use_bias=True, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, final_layer_scale=None, rngs)
Vector-output network with ICNN-like architecture.
Introduced in [Olausson et al., 2026] to amortize maximum inner product search with learned support functions. Unlike ICNN, which outputs a scalar convex function and requires autodiff to compute gradients, KeyNet directly outputs vectors. The architecture mirrors ICNN but drops the non-negativity constraint on the layer-to-layer weights \(W_z\).
Architecture:
    z_0 = act(W_x0 @ x + b_0)
    z_i = act(W_z_i @ z_{i-1} + W_x_i @ x + b_i)   # wx_inject controls W_x_i
    z_N = W_z_N @ z_{N-1} + W_x_N @ x + b_N        # final layer is linear (no act)
    out = x + z_N if resnet else z_N
Biases \(b_i\) (on W_x0 and the W_z layers) are included only when use_bias=True; the W_x input-injection terms are always bias-free. The final layer applies no activation, so the output vector is unconstrained (e.g. signed gradients). When resnet=True the output is \(x + F(x)\), i.e. the model learns a correction to the input query. The scalar potential is recovered as \(f(x) = \langle \mathrm{KeyNet}(x), x \rangle\).
- Parameters:
dim_hidden (Sequence[int]) – Sequence of hidden layer sizes.
input_dim (int) – Dimension of the input x.
output_dim (Optional[int]) – Output vector dimension. Defaults to input_dim. Typically equals the input dimension for gradient-of-potential interpretation.
resnet (bool) – If True, output \(x + F(x)\) instead of \(F(x)\).
wx_inject (Union[bool, Tuple[bool, ...], int]) – Controls input re-injection pattern.
use_bias (bool) – Whether to use bias terms.
kernel_init (Union[Initializer, Callable[..., Any]]) – Initializer for all weights.
bias_init (Union[Initializer, Callable[..., Any]]) – Initializer for biases.
final_layer_scale (Optional[float]) – Scale for final layer init. Defaults to 0.01 for resnet mode (small initial corrections), 1.0 otherwise.
rngs (Rngs) – Random number generators.
args (Any)
kwargs (Any)
- Return type:
Any
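The architecture described above can be sketched in plain JAX. This is a minimal illustration under stated assumptions, not the library's implementation: the helpers `init_params` and `keynet_forward`, the per-layer dictionary layout, the scaled-normal initialization, and the choice of `jax.nn.relu` as activation are all hypothetical. Input injection is applied at every layer (the `wx_inject=True` case), biases are always present, and at least one hidden layer is assumed.

```python
import jax
import jax.numpy as jnp


def init_params(key, input_dim, dim_hidden, output_dim):
    """Hypothetical parameter layout: one dict per layer (hidden layers + final)."""
    sizes = list(dim_hidden) + [output_dim]
    params, prev = [], None
    for i, size in enumerate(sizes):
        key, kx, kz = jax.random.split(key, 3)
        layer = {
            # Input-injection weights W_x_i; they carry no bias of their own.
            "W_x": jax.random.normal(kx, (size, input_dim)) / jnp.sqrt(input_dim),
            "b": jnp.zeros(size),
        }
        if i > 0:
            # Layer-to-layer weights W_z_i: unconstrained in sign, unlike ICNN.
            layer["W_z"] = jax.random.normal(kz, (size, prev)) / jnp.sqrt(prev)
        params.append(layer)
        prev = size
    return params


def keynet_forward(params, x, resnet=False, act=jax.nn.relu):
    """Forward pass following the z_0, z_i, z_N recursion for a single input x."""
    z = act(params[0]["W_x"] @ x + params[0]["b"])
    for layer in params[1:-1]:
        z = act(layer["W_z"] @ z + layer["W_x"] @ x + layer["b"])
    last = params[-1]
    z = last["W_z"] @ z + last["W_x"] @ x + last["b"]  # final layer: no activation
    return x + z if resnet else z


key = jax.random.PRNGKey(0)
params = init_params(key, input_dim=4, dim_hidden=(16, 16), output_dim=4)
x = jnp.arange(4.0)

out = keynet_forward(params, x)                   # F(x)
out_res = keynet_forward(params, x, resnet=True)  # x + F(x)
```

With `resnet=True` the output dimension has to match the input dimension, since the input is added to the final layer's output. A batch of queries can be processed by wrapping `keynet_forward` in `jax.vmap` over `x`.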
Methods
eval(**attributes) – Sets the Module to evaluation mode.
gradient(x) – Compute the vector output (predicted gradient / key).
Warning: this method is deprecated; use iter_children() instead.
Warning: this method is deprecated; use iter_modules() instead.
perturb(name, value[, variable_type]) – Extract gradients of intermediate values during training.
sow(variable_type, name, value[, reduce_fn, ...]) – Store intermediate values during module execution for later extraction.
train(**attributes) – Sets the Module to training mode.
Attributes
KeyNet models a potential via \(f(x) = \langle g(x), x \rangle\).
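The relation \(f(x) = \langle g(x), x \rangle\) can be illustrated with a toy vector map in place of a trained network. In this sketch the names `potential` and `unit_ball_key` are hypothetical and not part of the library: `unit_ball_key` returns the maximizer of \(\langle x, y \rangle\) over the unit ball, whose support function is the Euclidean norm, so the potential and its gradient can be checked by hand.

```python
import jax
import jax.numpy as jnp


def potential(g, x):
    """Scalar potential f(x) = <g(x), x> for a vector-valued map g."""
    return jnp.vdot(g(x), x)


def unit_ball_key(x):
    """Hypothetical stand-in for a trained KeyNet: argmax of <x, y> over ||y|| <= 1."""
    return x / jnp.linalg.norm(x)


x = jnp.array([3.0, 4.0])

# f(x) = <x / ||x||, x> = ||x||, the support function of the unit ball.
f_x = potential(unit_ball_key, x)

# For this exact support function, autodiff of f recovers the key itself.
grad_f = jax.grad(lambda v: potential(unit_ball_key, v))(x)
```

For this idealized map the autodiff gradient of the potential coincides with the vector output. A learned network only approximates such a map, so the two need not agree exactly in practice.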