ott.neural.layers.PosDefPotentials


class ott.neural.layers.PosDefPotentials(num_potentials, rank=0, rectifier_fn=<jax._src.custom_derivatives.custom_jvp object>, use_linear=True, use_bias=True, kernel_lr_init=<function <lambda>>, kernel_diag_init=<function ones>, kernel_linear_init=<function <lambda>>, bias_init=<function zeros>, precision=None, parent=<flax.linen.module._Sentinel object>, name=None)

\(\frac{1}{2} x^T (A_i A_i^T + \text{Diag}(d_i)) x + b_i^T x + c_i\) potentials.
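Why these potentials are called positive definite follows from a short computation. The derivation below is the standard argument for a quadratic function with a linear term \(b_i^T x\); it is a sketch written for this page, not text taken from the library.

```latex
\begin{aligned}
f_i(x) &= \tfrac{1}{2}\, x^T M_i x + b_i^T x + c_i,
  \qquad M_i = A_i A_i^T + \operatorname{Diag}(d_i) = M_i^T, \\
\nabla f_i(x) &= M_i x + b_i,
  \qquad \nabla^2 f_i(x) = M_i, \\
v^T M_i v &= \lVert A_i^T v \rVert^2 + \sum_{j} d_{ij} v_j^2 \;>\; 0
  \quad \text{for all } v \neq 0 \text{ whenever all } d_{ij} > 0 .
\end{aligned}
```

So each \(f_i\) is strictly convex, and its gradient \(x \mapsto M_i x + b_i\) is an affine map.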

This class implements a layer that takes a (batch of) d-dimensional vector(s) x as input and outputs a num_potentials-dimensional vector for each of them. Each entry of that output is a quadratic function of x whose Hessian, \(A_i A_i^T + \text{Diag}(d_i)\), is positive definite; this matrix is parameterized as a low-rank plus diagonal term. The low-rank term is parameterized as \(A_i A_i^T\), where each \(A_i\) is of size (rank, d); taken together, these matrices form a tensor of shape (num_potentials, rank, d). The diagonal terms \(d_i\) form a (num_potentials, d) matrix of positive values, and the linear terms \(b_i\) form a (num_potentials, d) matrix. Finally, the constants \(c_i\) are contained in a vector of size (num_potentials,).
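To make the shapes above concrete, here is a minimal pure-JAX sketch that evaluates the potentials from explicitly given parameter arrays. It is not the library's implementation: the function pos_def_potentials and its argument names are hypothetical. Each (rank, d) factor is read as a matrix F_i whose rows are the low-rank directions, so the low-rank term is computed as \(\frac{1}{2}\lVert F_i x\rVert^2 = \frac{1}{2} x^T F_i^T F_i x\), i.e. \(A_i = F_i^T\) in the notation of the formula; this reading of the stored shape is an assumption.

```python
import jax
import jax.numpy as jnp


def pos_def_potentials(x, low_rank, diag, linear, bias):
    """Evaluate num_potentials quadratic potentials at a batch of points.

    Hypothetical helper for illustration (not part of ott). Shapes:
      x:        (batch, d)
      low_rank: (num_potentials, rank, d), rows of F_i, with A_i = F_i^T (assumed)
      diag:     (num_potentials, d), must be positive
      linear:   (num_potentials, d)
      bias:     (num_potentials,)
    Returns an array of shape (batch, num_potentials).
    """
    # Low-rank part: 0.5 * ||F_i x||^2 for every potential i and point x.
    proj = jnp.einsum("ird,bd->bir", low_rank, x)  # (batch, num_potentials, rank)
    quad_lr = 0.5 * jnp.sum(proj**2, axis=-1)       # (batch, num_potentials)
    # Diagonal part: 0.5 * sum_j d_ij * x_j^2.
    quad_diag = 0.5 * jnp.einsum("id,bd->bi", diag, x**2)
    # Linear part b_i^T x; the constant c_i is broadcast over the batch.
    lin = jnp.einsum("id,bd->bi", linear, x)
    return quad_lr + quad_diag + lin + bias


# Usage with random parameters: 4 points in d=3, 2 potentials of rank 1.
key = jax.random.PRNGKey(0)
k_x, k_f, k_d, k_b, k_c = jax.random.split(key, 5)
x = jax.random.normal(k_x, (4, 3))
low_rank = jax.random.normal(k_f, (2, 1, 3))
diag = jax.nn.softplus(jax.random.normal(k_d, (2, 3)))  # strictly positive diagonal
linear = jax.random.normal(k_b, (2, 3))
bias = jax.random.normal(k_c, (2,))
out = pos_def_potentials(x, low_rank, diag, linear, bias)
print(out.shape)  # (4, 2)
```

Working with einsum over the factors avoids ever forming the d x d matrices, which is the point of a low-rank plus diagonal parameterization when d is large.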

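Parameters:

num_potentials – Number of potentials, i.e. the dimension of the output vector.

rank – Rank of the low-rank term \(A_i A_i^T\); with rank=0 (the default) the quadratic part reduces to the diagonal term.

rectifier_fn – Function applied to the diagonal parameters so that the \(d_i\) stay positive.

use_linear – Whether to include the linear terms \(b_i\).

use_bias – Whether to include the constants \(c_i\).

kernel_lr_init – Initializer for the low-rank factors \(A_i\).

kernel_diag_init – Initializer for the diagonal terms \(d_i\).

kernel_linear_init – Initializer for the linear terms \(b_i\).

bias_init – Initializer for the constants \(c_i\).

precision – Numerical precision used in the layer's computations (default None).

parent, name – Standard Flax Module arguments.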

Methods

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns True if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_from_samples(source, target, **kwargs)

Initializes the layer using a Gaussian approximation [Bunne et al., 2022].

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns True if the collection col is mutable.

kernel_diag_init(shape[, dtype])

An initializer that returns a constant array full of ones.

kernel_linear_init(**k)

kernel_lr_init(**k)

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Adds a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.
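The init_from_samples method is described only as using a Gaussian approximation [Bunne et al., 2022]. As a rough illustration of why a quadratic potential suits that idea, the sketch below fits Gaussians to source and target samples and builds the quadratic potential \(\frac{1}{2} x^T A x + b^T x\) whose gradient is the closed-form optimal transport map for the squared Euclidean cost between the two fits, \(T(x) = m_t + A(x - m_s)\) with \(A = \Sigma_s^{-1/2}(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2})^{1/2}\Sigma_s^{-1/2}\). This is an assumption about the general technique, not a description of what init_from_samples actually computes; sqrtm_psd and gaussian_quadratic_potential are hypothetical helpers.

```python
import jax
import jax.numpy as jnp


def sqrtm_psd(mat):
    """Symmetric square root of a symmetric positive definite matrix via eigh."""
    w, v = jnp.linalg.eigh(mat)
    return (v * jnp.sqrt(w)) @ v.T  # V diag(sqrt(w)) V^T


def gaussian_quadratic_potential(source, target):
    """Hypothetical helper: fit Gaussians to (n, d) samples and return (A, b)
    such that the gradient of 0.5 * x^T A x + b^T x is the closed-form OT map
    between the Gaussian fits (squared Euclidean cost)."""
    m_s, m_t = source.mean(axis=0), target.mean(axis=0)
    cov_s = jnp.cov(source, rowvar=False)
    cov_t = jnp.cov(target, rowvar=False)
    s_half = sqrtm_psd(cov_s)
    s_half_inv = jnp.linalg.inv(s_half)
    a = s_half_inv @ sqrtm_psd(s_half @ cov_t @ s_half) @ s_half_inv
    b = m_t - a @ m_s  # so that A x + b = m_t + A (x - m_s)
    return a, b


# Usage: standard normal source, shifted and sheared target in d=2.
key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
source = jax.random.normal(key_s, (2000, 2))
target = 3.0 + jax.random.normal(key_t, (2000, 2)) @ jnp.array([[2.0, 0.0], [0.5, 1.0]])
a, b = gaussian_quadratic_potential(source, target)
mapped = source @ a.T + b  # gradient of the potential at each source point
```

Because A is symmetric positive definite, the resulting potential is convex and has exactly the quadratic-plus-linear-plus-constant form that a single positive definite potential can express.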

Attributes

name

parent

path

precision

rank

rectifier_fn(x)

scope

use_bias

use_linear

variables

Returns the variables in this module.

num_potentials