ott.neural.networks.layers.posdef.PosDefPotentials
- class ott.neural.networks.layers.posdef.PosDefPotentials(num_potentials, rank=0, rectifier_fn=<jax._src.custom_derivatives.custom_jvp object>, use_linear=True, use_bias=True, kernel_lr_init=<function <lambda>>, kernel_diag_init=<function ones>, kernel_linear_init=<function <lambda>>, bias_init=<function zeros>, precision=None, parent=<flax.linen.module._Sentinel object>, name=None)
\(\frac{1}{2} x^T (A_i A_i^T + \text{Diag}(d_i)) x + b_i^T x + c_i\) potentials.
This class implements a layer that takes (batched) d-dimensional vectors x as input and outputs a num_potentials-dimensional vector. Each entry of that output is a positive definite quadratic form evaluated at x; each of these quadratic terms is parameterized as a low-rank plus diagonal matrix. The low-rank term is parameterized as \(A_i A_i^T\), where each of these matrices is of size (rank, d). Taken together, these matrices form a tensor of shape (num_potentials, rank, d). The diagonal terms \(d_i\) form a (num_potentials, d) matrix of positive values; the linear terms \(b_i\) form a (num_potentials, d) matrix. Finally, the \(c_i\) are contained in a vector of size (num_potentials,).

- Parameters:
num_potentials (int) – Dimension of the output.
rank (int) – Rank of the matrices \(A_i\) used as low-rank factors for the quadratic potentials.
rectifier_fn (Callable[[Array], Array]) – Rectifier function to ensure non-negativity of the diagonals \(d_i\). The default is relu().
use_linear (bool) – Whether to add linear layers \(b_i\) to the outputs.
use_bias (bool) – Whether to add biases \(c_i\) to the outputs.
kernel_lr_init (Callable[[Array, Tuple[int, ...], Any], Array]) – Initializer for the matrices \(A_i\) of the quadratic potentials when rank > 0. The default is lecun_normal().
kernel_diag_init (Callable[[Array, Tuple[int, ...], Any], Array]) – Initializer for the diagonals \(d_i\). The default is ones().
kernel_linear_init (Callable[[Array, Tuple[int, ...], Any], Array]) – Initializer for the linear layers \(b_i\). The default is lecun_normal().
bias_init (Callable[[Array, Tuple[int, ...], Any], Array]) – Initializer for the bias. The default is zeros().
precision (Optional[Precision]) – Numerical precision of the computation.
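To make the formula concrete, here is a minimal sketch in plain Python of the value the layer is described as computing for a single input vector. It is an illustration under stated assumptions, not the library's implementation: the helper names (posdef_potentials, relu, dot), the nested-list parameter layout, and the choice to treat each of the rank rows of \(A_i\) as one low-rank factor are all hypothetical, and \(b_i\) is treated as a linear term, as the parameter descriptions state.

```python
def relu(v):
    """Elementwise rectifier; keeps the diagonal entries non-negative."""
    return [max(t, 0.0) for t in v]


def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def posdef_potentials(x, lowrank, diag, linear, bias):
    """Evaluate every potential at one d-dimensional point x.

    Hypothetical layout (assumption, mirrors the shapes in the description):
      lowrank: num_potentials entries, each a list of `rank` vectors of length d
      diag:    num_potentials vectors of length d (rectified before use)
      linear:  num_potentials vectors of length d
      bias:    num_potentials scalars
    """
    out = []
    for A_i, d_i, b_i, c_i in zip(lowrank, diag, linear, bias):
        # Low-rank quadratic term: sum of squared projections onto the factors.
        low_rank_term = sum(dot(a, x) ** 2 for a in A_i)
        # Diagonal quadratic term with rectified (non-negative) entries.
        diag_term = sum(di * xi * xi for di, xi in zip(relu(d_i), x))
        out.append(0.5 * (low_rank_term + diag_term) + dot(b_i, x) + c_i)
    return out


# Toy values: num_potentials=2, rank=1, d=2.
x = [1.0, 2.0]
lowrank = [[[1.0, 0.0]], [[0.0, 1.0]]]
diag = [[1.0, 1.0], [-1.0, 2.0]]  # the negative entry is clipped to 0 by relu
linear = [[0.0, 0.0], [1.0, 0.0]]
bias = [0.0, 0.5]

values = posdef_potentials(x, lowrank, diag, linear, bias)
print(values)
```

With these toy values the first potential evaluates to 3.0 and the second to 7.5; the negative diagonal entry contributes nothing because the rectifier clips it to zero.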
Methods
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bias_init(shape[, dtype]) – An initializer that returns a constant array full of zeros.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_from_samples(source, target, **kwargs) – Initialize the layer using Gaussian approximation [Bunne et al., 2022].
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
Returns True if running under self.init(...) or nn.init(...)().
Returns true if the collection col is mutable.
kernel_diag_init(shape[, dtype]) – An initializer that returns a constant array full of ones.
kernel_linear_init(**k)
kernel_lr_init(**k)
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
Attributes
rectifier_fn(x)
Returns the variables in this module.