ott.neural.networks.icnn.ICNN
- class ott.neural.networks.icnn.ICNN(dim_data, dim_hidden, ranks=1, init_fn=<function <lambda>>, act_fn=relu, pos_weights=False, rectifier_fn=relu, gaussian_map_samples=None, parent=<flax.linen.module._Sentinel object>, name=None)
Input convex neural network (ICNN).
Implementation of input convex neural networks as introduced in [Amos et al., 2017], with the initialization schemes proposed by [Bunne et al., 2022] and the (low-rank + diagonal) quadratic terms on the inputs at each layer proposed by [Vesseron and Cuturi, 2024].
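The convexity argument behind this architecture can be checked numerically. The following is a minimal NumPy sketch, not OTT's implementation: the helper names (init_params, icnn_forward, quad) and the choice of adding one scalar quadratic form x^T (A A^T + diag(d^2)) x to every pre-activation of a layer are illustrative assumptions. Hidden-to-hidden weights are passed through relu so they are non-negative, hidden activations are relu (convex and non-decreasing), and the final layer is linear with output dimension 1.

```python
import numpy as np

rng = np.random.default_rng(0)


def relu(x):
    return np.maximum(x, 0.0)


def quad(x, A, diag):
    # Low-rank + diagonal PSD form: x^T (A A^T + diag(diag**2)) x, convex in x.
    # Assumption: one scalar quadratic per layer, added to every unit; OTT's layout may differ.
    return np.sum((A.T @ x) ** 2) + np.sum((diag * x) ** 2)


def init_params(dim_data, dim_hidden, rank=1):
    """Random toy parameters for hidden sizes `dim_hidden` and a final 1-d layer."""
    params, prev = [], None
    for h in list(dim_hidden) + [1]:
        layer = {
            "Wx": rng.normal(size=(h, dim_data)),  # skip connection from the input
            "b": rng.normal(size=h),
            "A": rng.normal(size=(dim_data, rank)),  # low-rank factor
            "diag": rng.normal(size=dim_data),  # diagonal part (squared in `quad`)
        }
        if prev is not None:
            # Hidden-to-hidden weights: unconstrained here, rectified in the forward pass.
            layer["Wz"] = rng.normal(size=(h, prev))
        params.append(layer)
        prev = h
    return params


def icnn_forward(params, x):
    """Scalar output of the toy ICNN at a single input x of shape (dim_data,)."""
    z = None
    for i, layer in enumerate(params):
        pre = layer["Wx"] @ x + layer["b"] + quad(x, layer["A"], layer["diag"])
        if z is not None:
            # relu(Wz) >= 0, so this is a non-negative combination of convex functions.
            pre = pre + relu(layer["Wz"]) @ z
        # Convex, non-decreasing activation on hidden layers; the last layer is linear.
        z = relu(pre) if i < len(params) - 1 else pre
    return z[0]


# Convexity inequality f(t x + (1 - t) y) <= t f(x) + (1 - t) f(y) on a random segment.
params = init_params(dim_data=3, dim_hidden=[16, 16], rank=2)
x, y, t = rng.normal(size=3), rng.normal(size=3), 0.3
lhs = icnn_forward(params, t * x + (1 - t) * y)
rhs = t * icnn_forward(params, x) + (1 - t) * icnn_forward(params, y)
print(lhs <= rhs + 1e-8 * max(1.0, abs(rhs)))  # expected: True
```

Dropping the relu on Wz (allowing negative hidden-to-hidden weights) would generally break the inequality, which is why the weights must be kept non-negative.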
- Parameters:
  - dim_data (int) – data dimensionality.
  - dim_hidden (Sequence[int]) – sequence specifying the sizes of the hidden layers. The output dimension of the last layer is 1 by default.
  - ranks (Union[int, Tuple[int, ...]]) – ranks of the matrices \(A_i\) used as low-rank factors for the quadratic potentials. If a sequence is passed, it must contain len(dim_hidden) + 2 elements, where the last 2 elements correspond to the ranks of the final layer (of dimension 1) and of the potentials, respectively.
  - init_fn (Callable[[Array, Tuple[int, ...], Any], Array]) – initializer for the kernel weight matrices. The default is normal().
  - act_fn (Callable[[Array], Array]) – activation function used in the network architecture; it needs to be convex (and non-decreasing) for the network to remain convex in its input. The default is relu().
  - pos_weights (bool) – whether to enforce positive weights with a projection. If False, positivity of the weights should instead be enforced with clipping or with a regularization term in the loss.
  - rectifier_fn (Callable[[Array], Array]) – function used to ensure the non-negativity of the weights. The default is relu().
  - gaussian_map_samples (Optional[Tuple[Array, Array]]) – tuple of source and target points, used to initialize the ICNN to mimic the linear Bures map that morphs the Gaussian approximation of the input measure into that of the target measure. If None, the identity initialization is used, and the ICNN mimics half the squared Euclidean norm.
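The length rule for a sequence-valued ranks argument can be made concrete with a small validation helper. This is a hypothetical sketch (expand_ranks is not an OTT function), and it assumes that a single int is shared by every quadratic term, which the parameter description implies but does not state.

```python
from typing import Sequence, Tuple, Union


def expand_ranks(ranks: Union[int, Sequence[int]], dim_hidden: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: return one rank per quadratic term of the ICNN."""
    # len(dim_hidden) hidden layers + the final 1-d layer + the potentials.
    n_expected = len(dim_hidden) + 2
    if isinstance(ranks, int):
        # Assumption: a single int is shared by every quadratic term.
        return (ranks,) * n_expected
    ranks = tuple(ranks)
    if len(ranks) != n_expected:
        raise ValueError(f"expected {n_expected} ranks, got {len(ranks)}")
    return ranks


# Two hidden layers -> 4 ranks: 2 hidden, then final layer, then potentials.
ranks = expand_ranks((4, 4, 2, 1), dim_hidden=[64, 64])
hidden_ranks, final_layer_rank, potential_rank = ranks[:-2], ranks[-2], ranks[-1]
```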
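For intuition about gaussian_map_samples, the closed-form optimal transport (Bures) map between two Gaussians, and the quadratic potential whose gradient it is, can be computed directly. This NumPy sketch fits means and covariances to the two point clouds; sqrtm_psd and bures_potential are hypothetical names, and how OTT writes the resulting (A, b) into the ICNN's weights is not shown here. With A = I and b = 0 the potential reduces to half the squared Euclidean norm, matching the behavior described for None.

```python
import numpy as np


def sqrtm_psd(S):
    # Square root of a symmetric PSD matrix via its eigendecomposition.
    w, V = np.linalg.eigh(S)
    return (V * np.sqrt(np.clip(w, 0.0, None))) @ V.T


def bures_potential(source, target):
    """Return (A, b) with f(x) = 0.5 x^T A x + b^T x, whose gradient A x + b
    is the OT map between Gaussian fits of `source` and `target`."""
    m_s, m_t = source.mean(axis=0), target.mean(axis=0)
    cov_s = np.cov(source, rowvar=False)
    cov_t = np.cov(target, rowvar=False)
    s_half = sqrtm_psd(cov_s)
    s_inv_half = np.linalg.inv(s_half)
    # A = S_s^{-1/2} (S_s^{1/2} S_t S_s^{1/2})^{1/2} S_s^{-1/2}: symmetric PD, A S_s A = S_t.
    A = s_inv_half @ sqrtm_psd(s_half @ cov_t @ s_half) @ s_inv_half
    b = m_t - A @ m_s  # sends the source mean to the target mean
    return A, b


rng = np.random.default_rng(0)
source = rng.normal(size=(500, 2))
target = rng.normal(size=(500, 2)) @ np.array([[2.0, 0.0], [0.5, 1.0]]) + np.array([3.0, -1.0])
A, b = bures_potential(source, target)
mapped = source @ A.T + b  # apply grad f row-wise
```

The mapped points reproduce the target's empirical mean and covariance exactly (up to rounding), which is what "morphing the Gaussian approximation" of one measure into the other means.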
Methods
- apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
- clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
- copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
- create_train_state(rng, optimizer, input, ...) – Creates the initial training state.
- get_variable(col, name[, default]) – Retrieves the value of a Variable.
- has_rng(name) – Returns true if a PRNGSequence with the given name exists.
- has_variable(col, name) – Checks if a variable of the given collection and name exists in this Module.
- init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
- init_fn(**k)
- init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
- is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
- is_mutable_collection(col) – Returns true if the collection col is mutable.
- lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
- make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
- param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
- perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
- potential_gradient_fn(params) – Returns a function that computes either a vector output or the gradient of the potential.
- potential_value_fn(params[, ...]) – Returns a function giving the value of the potential.
- put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
- setup() – Initializes a Module lazily (similar to a lazy __init__).
- sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
- unbind() – Returns an unbound copy of a Module and its variables.
- variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
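According to the method summaries, potential_value_fn returns a function evaluating the potential, while potential_gradient_fn returns a function evaluating its gradient (the candidate transport map). The relationship between the two can be illustrated without OTT. In a JAX codebase the gradient would normally come from automatic differentiation such as jax.grad; in this dependency-free sketch, central finite differences stand in for it, and make_gradient_fn and half_sq_norm are hypothetical names, not OTT functions.

```python
from typing import Callable

import numpy as np


def make_gradient_fn(value_fn: Callable[[np.ndarray], float], eps: float = 1e-5):
    """Wrap a scalar potential into a function returning its gradient."""
    def grad_fn(x):
        x = np.asarray(x, dtype=float)
        g = np.zeros_like(x)
        for i in range(x.size):
            e = np.zeros_like(x)
            e[i] = eps
            # Central difference along coordinate i (stand-in for autodiff).
            g[i] = (value_fn(x + e) - value_fn(x - e)) / (2.0 * eps)
        return g
    return grad_fn


def half_sq_norm(x):
    # Potential mimicked by the identity initialization: f(x) = 0.5 * ||x||^2.
    return 0.5 * float(np.dot(x, x))


grad_fn = make_gradient_fn(half_sq_norm)
x = np.array([1.0, -2.0, 0.5])
print(grad_fn(x))  # approximately x itself: the gradient of 0.5 * ||x||^2 is the identity map
```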
Attributes
- act_fn(x)
- is_potential – Indicates if the module implements a potential value or a vector field.
- rectifier_fn(x)
- variables – Returns the variables in this module.
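The rectifier_fn attribute and the pos_weights parameter both concern keeping the hidden-to-hidden weights non-negative. When pos_weights is False, the parameter description leaves this to clipping or to a regularization term in the loss; both options are sketched below in NumPy. clip_nonneg and nonneg_penalty are hypothetical helpers, and the squared-hinge form of the penalty is one possible choice, not necessarily what OTT's training code uses.

```python
import numpy as np


def clip_nonneg(weights):
    # Projection onto the non-negative orthant (elementwise relu), applied after each update.
    return [np.maximum(W, 0.0) for W in weights]


def nonneg_penalty(weights, strength=1.0):
    # Squared magnitude of the negative entries; zero iff all weights are >= 0.
    return strength * sum(np.sum(np.maximum(-W, 0.0) ** 2) for W in weights)


Wz = [np.array([[0.5, -1.0], [2.0, -0.25]])]
print(nonneg_penalty(Wz))  # 1.0625
clipped = clip_nonneg(Wz)  # negative entries replaced by 0.0
```

Clipping enforces the constraint exactly at every step, while the penalty only discourages negative weights, so convexity of the learned potential is guaranteed only in the first case.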