ott.neural.networks.layers.posdef.PosDefPotentials
- class ott.neural.networks.layers.posdef.PosDefPotentials(in_features, num_potentials, *, rank=1, use_linear=True, use_bias=True, kernel_diag_init=<function constant.<locals>.init>, kernel_lr_init=<function variance_scaling.<locals>.init>, kernel_linear_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, rectifier_fn=<PjitFunction of <function softplus>>, rngs)
Low-rank plus diagonal positive definite quadratic potentials.
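Computes, for each potential i = 1, ..., num_potentials: 0.5 * x^T (A_i A_i^T + diag(d_i)) x + b_i^T x + c_i, where the linear term b_i^T x is included only if use_linear is True and the constant c_i only if use_bias is True.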
This is used as an optional additive term in the ICNN to ensure strong convexity.
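The formula can be checked numerically with a minimal pure-JAX sketch. This is not OTT's implementation: the function name `posdef_potentials`, the parameter shapes, and the assumption that `rectifier_fn` (softplus by default in the signature) is applied to unconstrained diagonal parameters to obtain d_i are all illustrative choices. The sketch returns one value per potential (an output of size `num_potentials`) and checks that the Hessian of each potential is the constant matrix A_i A_i^T + diag(d_i), which is why the term is strongly convex whenever every d_i entry is positive.

```python
import jax
import jax.numpy as jnp


def posdef_potentials(x, A, d_raw, b=None, c=None, rectifier_fn=jax.nn.softplus):
    """Hypothetical reference for 0.5 x^T (A_i A_i^T + diag(d_i)) x + b_i^T x + c_i.

    Assumed shapes (illustrative, not OTT's internal layout):
      x:     (batch, in_features)
      A:     (num_potentials, in_features, rank)  low-rank factors A_i
      d_raw: (num_potentials, in_features)        unconstrained diagonal params
      b:     (num_potentials, in_features) | None linear terms b_i
      c:     (num_potentials,) | None             scalar biases c_i
    Returns one value per potential: shape (batch, num_potentials).
    """
    d = rectifier_fn(d_raw)  # softplus keeps the diagonal entries d_i positive
    # x^T A_i A_i^T x == ||A_i^T x||^2, so no in_features x in_features matrix is built.
    proj = jnp.einsum("bf,pfr->bpr", x, A)
    low_rank = jnp.sum(proj**2, axis=-1)
    diag = jnp.einsum("bf,pf->bp", x**2, d)  # x^T diag(d_i) x
    out = 0.5 * (low_rank + diag)
    if b is not None:
        out = out + jnp.einsum("bf,pf->bp", x, b)  # b_i^T x
    if c is not None:
        out = out + c  # broadcast c_i over the batch
    return out


# Usage: random parameters, then compare against the dense formula.
kx, ka, kd, kb, kc = jax.random.split(jax.random.PRNGKey(0), 5)
batch, in_features, num_potentials, rank = 4, 3, 2, 1
x = jax.random.normal(kx, (batch, in_features))
A = jax.random.normal(ka, (num_potentials, in_features, rank))
d_raw = jax.random.normal(kd, (num_potentials, in_features))
b = jax.random.normal(kb, (num_potentials, in_features))
c = jax.random.normal(kc, (num_potentials,))

y = posdef_potentials(x, A, d_raw, b, c)  # shape (batch, num_potentials)

# Dense reference for potential 0, building Q_0 = A_0 A_0^T + diag(d_0) explicitly.
Q0 = A[0] @ A[0].T + jnp.diag(jax.nn.softplus(d_raw[0]))
dense = 0.5 * jnp.einsum("bf,fg,bg->b", x, Q0, x) + x @ b[0] + c[0]

# The Hessian of each potential is the constant matrix Q_i; since A_i A_i^T is
# positive semidefinite, its smallest eigenvalue is at least min_j d_ij > 0.
hess = jax.hessian(lambda z: posdef_potentials(z[None, :], A, d_raw, b, c)[0, 0])(x[0])
min_eig = jnp.linalg.eigvalsh(hess).min()
```

Computing ||A_i^T x||^2 instead of forming A_i A_i^T keeps the cost linear in in_features for a fixed rank.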
- Parameters:
in_features (int) – Input dimension.
num_potentials (int) – Number of output potentials.
rank (int) – Rank of the low-rank factors A_i.
use_linear (bool) – Whether to include the linear term b_i^T x.
use_bias (bool) – Whether to include the scalar bias c_i.
rngs (Rngs) – Random number generators.
kernel_diag_init (Union[Initializer, Callable[..., Any]]) – Initializer for the diagonal terms d_i.
kernel_lr_init (Union[Initializer, Callable[..., Any]]) – Initializer for the low-rank factors A_i.
kernel_linear_init (Union[Initializer, Callable[..., Any]]) – Initializer for the linear terms b_i.
bias_init (Union[Initializer, Callable[..., Any]]) – Initializer for the biases c_i.
args (Any)
kwargs (Any)
- Return type:
Any
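To see what the `rank` parameter buys, here is a back-of-the-envelope count of the scalars implied by the formula, taking A_i as in_features x rank, d_i and b_i as vectors of length in_features, and c_i as a scalar. Both helpers are hypothetical and derived from the formula alone; OTT may store or share parameters differently.

```python
def posdef_param_count(in_features: int, num_potentials: int, rank: int = 1,
                       use_linear: bool = True, use_bias: bool = True) -> int:
    """Scalars implied by 0.5 x^T (A_i A_i^T + diag(d_i)) x + b_i^T x + c_i."""
    per_potential = in_features * rank  # low-rank factor A_i
    per_potential += in_features        # diagonal d_i
    if use_linear:
        per_potential += in_features    # linear term b_i
    if use_bias:
        per_potential += 1              # scalar bias c_i
    return num_potentials * per_potential


def dense_quadratic_count(in_features: int) -> int:
    """Free entries of a full symmetric in_features x in_features matrix."""
    return in_features * (in_features + 1) // 2


print(posdef_param_count(100, 1))   # 301
print(dense_quadratic_count(100))   # 5050
```

For a fixed rank the low-rank-plus-diagonal form grows linearly with in_features, while a dense symmetric quadratic grows quadratically; the diagonal term keeps the matrix positive definite even at rank 1.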
Methods
eval(**attributes) – Sets the Module to evaluation mode.
Warning: this method is deprecated; use iter_children() instead.
Warning: this method is deprecated; use iter_modules() instead.
perturb(name, value[, variable_type]) – Extract gradients of intermediate values during training.
sow(variable_type, name, value[, reduce_fn, ...]) – Store intermediate values during module execution for later extraction.
train(**attributes) – Sets the Module to training mode.