ott.neural.models.ICNN


class ott.neural.models.ICNN(dim_data, dim_hidden, ranks=1, init_fn=<function <lambda>>, act_fn=<jax._src.custom_derivatives.custom_jvp object>, pos_weights=False, rectifier_fn=<jax._src.custom_derivatives.custom_jvp object>, gaussian_map_samples=None, parent=<flax.linen.module._Sentinel object>, name=None)

Input convex neural network (ICNN).

Implementation of input convex neural networks as introduced in [Amos et al., 2017] with initialization schemes proposed by [Bunne et al., 2022].
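To make the convexity property concrete, here is a minimal NumPy sketch of an input convex forward pass. It is not the library's implementation: the helper names `init_params` and `icnn_forward` and the parameter layout are assumptions made for illustration only. Each layer adds an unconstrained affine function of the input to a non-negatively weighted combination of the previous layer, then applies a convex, non-decreasing activation, so the scalar output is convex in `x`.

```python
import numpy as np


def relu(x):
    return np.maximum(x, 0.0)


def init_params(rng, dim_data, dim_hidden):
    """Hypothetical parameter layout: one dict per layer, final layer has size 1."""
    sizes = list(dim_hidden) + [1]
    params, prev = [], None
    for size in sizes:
        params.append({
            # Weights acting directly on the input x (unconstrained).
            "Wx": rng.normal(size=(size, dim_data)),
            "b": rng.normal(size=(size,)),
            # Weights acting on the previous layer (absent in the first layer).
            "Wz": None if prev is None else rng.normal(size=(size, prev)),
        })
        prev = size
    return params


def icnn_forward(params, x):
    """Evaluate the toy network at a single point x of shape (dim_data,)."""
    z = None
    for i, layer in enumerate(params):
        pre = layer["Wx"] @ x + layer["b"]
        if z is not None:
            # Rectify the hidden-to-hidden weights so they are non-negative;
            # this is what preserves convexity from layer to layer.
            pre = pre + relu(layer["Wz"]) @ z
        is_last = i == len(params) - 1
        # No activation on the final (scalar) layer.
        z = pre if is_last else relu(pre)
    return float(z[0])


rng = np.random.default_rng(0)
params = init_params(rng, dim_data=3, dim_hidden=[8, 8])
x, y = rng.normal(size=3), rng.normal(size=3)
mid = icnn_forward(params, 0.5 * (x + y))
avg = 0.5 * (icnn_forward(params, x) + icnn_forward(params, y))
print(mid <= avg + 1e-9)  # midpoint convexity check
```

The sketch omits the quadratic low-rank terms and the initialization schemes mentioned above; it only illustrates why non-negative hidden weights and a convex activation yield a convex function of the input.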

Parameters:
  • dim_data (int) – data dimensionality.

  • dim_hidden (Sequence[int]) – sequence specifying the sizes of the hidden layers. The output dimension of the last layer is 1 by default.

  • ranks (Union[int, Tuple[int, ...]]) – ranks of the matrices \(A_i\) used as low-rank factors for the quadratic potentials. If a sequence is passed, it must contain len(dim_hidden) + 2 elements, where the last 2 elements correspond to the ranks of the final layer with dimension 1 and the potentials, respectively.

  • init_fn (Callable[[Array, Tuple[int, ...], Any], Array]) – Initializer for the kernel weight matrices. The default is normal().

  • act_fn (Callable[[Array], Array]) – activation function used in the network; it must be convex. The default is relu().

  • pos_weights (bool) – If True, enforce positive weights with a projection. If False, positivity should instead be enforced by clipping or by a regularization term in the loss.

  • rectifier_fn (Callable[[Array], Array]) – function used to ensure the non-negativity of the weights. The default is relu().

  • gaussian_map_samples (Optional[Tuple[Array, Array]]) – Tuple of source and target points, used to initialize the ICNN to mimic the linear Bures map that morphs the Gaussian approximation of the input measure to that of the target measure. If None, the identity initialization is used, and the ICNN mimics half the squared Euclidean norm.

  • parent (Union[Type[Module], Scope, Type[_Sentinel], None]) –

  • name (Optional[str]) –
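The linear Bures map mentioned for gaussian_map_samples has a closed form, which the following NumPy sketch computes from samples. This is an illustration under stated assumptions, not necessarily how the library performs its initialization; the helper names `sqrtm_psd` and `gaussian_bures_map` are hypothetical. For Gaussian approximations with means m_s, m_t and covariances Σ_s, Σ_t, the map is T(x) = m_t + A (x - m_s) with A = Σ_s^{-1/2} (Σ_s^{1/2} Σ_t Σ_s^{1/2})^{1/2} Σ_s^{-1/2}.

```python
import numpy as np


def sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # guard against tiny negative eigenvalues
    return (vecs * np.sqrt(vals)) @ vecs.T


def gaussian_bures_map(source, target):
    """Return (A, mean_s, mean_t) so that T(x) = mean_t + A @ (x - mean_s)."""
    mean_s, mean_t = source.mean(axis=0), target.mean(axis=0)
    cov_s = np.cov(source, rowvar=False)
    cov_t = np.cov(target, rowvar=False)
    sqrt_s = sqrtm_psd(cov_s)
    inv_sqrt_s = np.linalg.inv(sqrt_s)
    middle = sqrt_s @ cov_t @ sqrt_s
    middle = 0.5 * (middle + middle.T)  # symmetrize against round-off
    A = inv_sqrt_s @ sqrtm_psd(middle) @ inv_sqrt_s
    return A, mean_s, mean_t


rng = np.random.default_rng(0)
source = rng.normal(size=(500, 2))
# Synthetic target: an affine transformation of fresh Gaussian samples.
target = rng.normal(size=(500, 2)) @ np.array([[2.0, 0.5], [0.0, 1.0]]).T + 3.0

A, mean_s, mean_t = gaussian_bures_map(source, target)
mapped = mean_t + (source - mean_s) @ A.T
# The mapped samples reproduce the empirical target covariance.
print(np.allclose(np.cov(mapped, rowvar=False), np.cov(target, rowvar=False)))
```

Because A is symmetric positive semi-definite, T is the gradient of the convex quadratic f(x) = ½ (x - m_s)ᵀ A (x - m_s) + m_tᵀ x. When A is the identity and both means are zero, f reduces to half the squared Euclidean norm.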

Methods

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

create_train_state(rng, optimizer, input, ...)

Create initial training state.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with name name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_fn(**k)

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Adds a zero-value variable ('perturbation') to the intermediate value.

potential_gradient_fn(params)

Return a function returning a vector or the gradient of the potential.

potential_value_fn(params[, ...])

Return a function giving the value of the potential.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.
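The potential_value_fn and potential_gradient_fn entries above both follow a "parameters in, callable out" pattern. The sketch below illustrates that pattern with a hand-written quadratic potential in NumPy. The helpers `make_value_fn` and `make_gradient_fn` and the parameter dictionary are hypothetical stand-ins, not the library's methods, and the gradient uses central finite differences purely to keep the example dependency-free; in JAX one would typically differentiate the value function with `jax.grad` instead.

```python
import numpy as np


def make_value_fn(params):
    """Hypothetical stand-in: returns f(x) = 0.5 * x^T A x + b^T x."""
    A, b = params["A"], params["b"]

    def value(x):
        x = np.asarray(x, dtype=float)
        return float(0.5 * x @ A @ x + b @ x)

    return value


def make_gradient_fn(params, eps=1e-5):
    """Returns a function approximating the gradient of the potential."""
    value = make_value_fn(params)

    def gradient(x):
        x = np.asarray(x, dtype=float)
        grad = np.zeros_like(x)
        for i in range(x.size):
            step = np.zeros_like(x)
            step[i] = eps
            # Central difference along coordinate i.
            grad[i] = (value(x + step) - value(x - step)) / (2.0 * eps)
        return grad

    return gradient


# Identity potential: half the squared Euclidean norm, whose gradient is x itself.
params = {"A": np.eye(3), "b": np.zeros(3)}
grad_fn = make_gradient_fn(params)
x = np.array([1.0, -2.0, 0.5])
print(np.allclose(grad_fn(x), x, atol=1e-6))
```

With the identity potential, the gradient returns its input unchanged, which is the behaviour the identity initialization described for gaussian_map_samples is meant to mimic.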

Attributes

act_fn(x)

gaussian_map_samples

is_potential

Indicates if the module implements a potential value or a vector field.

name

parent

path

pos_weights

ranks

rectifier_fn(x)

scope

variables

Returns the variables in this module.

dim_data

dim_hidden