ott.neural.networks.layers.posdef.PosDefPotentials


class ott.neural.networks.layers.posdef.PosDefPotentials(in_features, num_potentials, *, rank=1, use_linear=True, use_bias=True, kernel_diag_init=<function constant.<locals>.init>, kernel_lr_init=<function variance_scaling.<locals>.init>, kernel_linear_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, rectifier_fn=<PjitFunction of <function softplus>>, rngs)

Low-rank plus diagonal positive definite quadratic potentials.

Computes: $\sum_i \big[ \tfrac{1}{2}\, x^\top \big(A_i A_i^\top + \operatorname{diag}(d_i)\big) x + b_i^\top x + c_i \big]$
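The formula can be evaluated directly, as in the minimal NumPy sketch below. It is not the OTT implementation, and several details are assumptions: the parameter shapes, the layout of `A`, and the idea that the diagonal is stored as unconstrained values `d_raw` mapped through softplus (suggested by the `rectifier_fn` default, but not confirmed here). The names `posdef_potentials` and `softplus` are hypothetical helpers.

```python
import numpy as np


def softplus(z):
    # Numerically stable log(1 + exp(z)); always strictly positive.
    return np.logaddexp(0.0, z)


def posdef_potentials(x, A, d_raw, b, c):
    """Evaluate sum_i [0.5 x^T (A_i A_i^T + diag(d_i)) x + b_i^T x + c_i].

    Hypothetical shapes (assumed, not taken from the OTT source):
      x:     (in_features,)
      A:     (num_potentials, in_features, rank)
      d_raw: (num_potentials, in_features), with d_i = softplus(d_raw_i)
      b:     (num_potentials, in_features)
      c:     (num_potentials,)
    """
    d = softplus(d_raw)  # assumed rectification keeping every d_i entry > 0
    # x^T A_i A_i^T x = ||A_i^T x||^2, so the dense n x n matrix is never built.
    lowrank = np.sum(np.einsum("ijr,j->ir", A, x) ** 2, axis=-1)
    diag = np.sum(d * x**2, axis=-1)   # x^T diag(d_i) x for each i
    quad = 0.5 * (lowrank + diag)      # quadratic part, one value per potential
    linear = b @ x                     # b_i^T x for each i
    return np.sum(quad + linear + c)   # sum over the potentials
```

Writing the low-rank part as a squared norm costs O(in_features × rank) per potential instead of O(in_features²), which is the usual reason for a low-rank-plus-diagonal parameterization.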

This is used as an optional additive term in the ICNN (input convex neural network) to ensure strong convexity.
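The strong-convexity claim follows from the Hessian of the formula above. This short derivation assumes that every diagonal entry $d_{i,j}$ is strictly positive, as it would be if the diagonal were passed through a softplus rectifier. The linear and constant terms $b_i^\top x + c_i$ do not contribute to the Hessian, and $A_i A_i^\top \succeq 0$:

```latex
\nabla^2 f(x) = \sum_i \big( A_i A_i^\top + \operatorname{diag}(d_i) \big)
  \;\succeq\; \sum_i \operatorname{diag}(d_i)
  \;\succeq\; \Big( \sum_i \min_j d_{i,j} \Big) I,
\qquad
\mu := \sum_i \min_j d_{i,j} > 0 .
```

So the potential is $\mu$-strongly convex. Adding it to a convex ICNN output keeps the sum $\mu$-strongly convex, because a convex function plus a $\mu$-strongly convex one is $\mu$-strongly convex.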

Return type:

Any

Methods

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use the module-level function flax.nnx.iter_children() instead.

iter_modules()

Warning: this method is deprecated; use the module-level function flax.nnx.iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.