ott.solvers.nn.icnn.ICNN
- class ott.solvers.nn.icnn.ICNN(dim_data, dim_hidden, init_std=0.1, init_fn=jax.nn.initializers.normal, act_fn=nn.relu, pos_weights=True, gaussian_map=None, parent=<flax.linen.module._Sentinel object>, name=None)
Input convex neural network (ICNN) architecture with initialization.
Implementation of input convex neural networks as introduced in [Amos et al., 2017] with initialization schemes proposed by [Bunne et al., 2022].
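The following is a minimal pure-JAX sketch of a fully connected ICNN in the style of [Amos et al., 2017]. It is not the ott implementation. The parameter layout, the helper names `init_icnn`, `icnn_forward` and `negative_weight_penalty`, and the use of `jnp.abs` to keep hidden-to-hidden weights non-negative are illustrative assumptions. The output is convex in `x` because each layer adds an affine function of `x` to a non-negative combination of convex functions and then applies a convex, non-decreasing activation. `negative_weight_penalty` shows one possible regularizer for the case where positivity is not enforced; the regularizer ott actually uses may differ.

```python
import jax
import jax.numpy as jnp


def init_icnn(key, dim_data, dim_hidden, init_std=0.1):
    """Hypothetical parameter layout for a fully connected ICNN.

    For k >= 1, layer k computes act(W_z[k-1] @ z_k + W_x[k] @ x + b[k]).
    W_x[k] maps the raw input (unconstrained); W_z[k-1] maps the previous
    hidden state and must be non-negative for the output to stay convex.
    """
    sizes = list(dim_hidden) + [1]  # the last layer outputs a scalar potential
    params = {"W_x": [], "W_z": [], "b": []}
    for k, out_dim in enumerate(sizes):
        key, kx, kz = jax.random.split(key, 3)
        params["W_x"].append(init_std * jax.random.normal(kx, (out_dim, dim_data)))
        params["b"].append(jnp.zeros(out_dim))
        if k > 0:  # the first layer has no previous hidden state
            params["W_z"].append(
                init_std * jax.random.normal(kz, (out_dim, sizes[k - 1])))
    return params


def icnn_forward(params, x, act_fn=jax.nn.relu, pos_weights=True):
    """Evaluates the potential f(x) for a single input vector x."""
    n_layers = len(params["W_x"])
    z = act_fn(params["W_x"][0] @ x + params["b"][0])
    for k in range(1, n_layers):
        W_z = params["W_z"][k - 1]
        if pos_weights:
            W_z = jnp.abs(W_z)  # assumed projection onto non-negative weights
        pre = W_z @ z + params["W_x"][k] @ x + params["b"][k]
        # Hidden layers use the activation; the output layer stays affine in z.
        z = pre if k == n_layers - 1 else act_fn(pre)
    return jnp.squeeze(z)


def negative_weight_penalty(params):
    """Hypothetical regularizer: squared magnitude of negative W_z entries."""
    return sum(jnp.sum(jax.nn.relu(-W_z) ** 2) for W_z in params["W_z"])
```

A quick sanity test of such a sketch is midpoint convexity on random pairs, f((x + y) / 2) <= (f(x) + f(y)) / 2, evaluated with `jax.vmap`.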
- Parameters
  - dim_data (int) – data dimensionality.
  - dim_hidden (Sequence[int]) – sizes of the hidden layers. The output dimension of the last layer is 1 by default.
  - init_std (float) – standard deviation used by the weight initialization method.
  - init_fn (Callable) – initialization method for the weight matrices (default: jax.nn.initializers.normal).
  - act_fn (Callable) – activation function used in the network (must be convex and non-decreasing; default: nn.relu).
  - pos_weights (bool) – if True, enforce positivity of the weights; otherwise, use a regularizer instead.
  - gaussian_map (Optional[Tuple[Array, Array]]) – samples from the source and target measures, used by the initialization scheme based on a Gaussian approximation of both measures. If None, the identity initialization is used.
  - parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]])
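The Gaussian initialization behind `gaussian_map` rests on a classical fact: for the squared-Euclidean cost, the optimal transport map between two Gaussians has a closed form, T(x) = m_t + A(x - m_s) with A = S_s^{-1/2} (S_s^{1/2} S_t S_s^{1/2})^{1/2} S_s^{-1/2}, and T is the gradient of the convex quadratic potential f(x) = 0.5 x^T A x + (m_t - A m_s)^T x. The sketch below fits Gaussians to source and target samples and builds that potential. `gaussian_ot_potential` and `_sqrtm_psd` are hypothetical helpers, the source covariance is assumed invertible, and how ott turns this potential into ICNN weights [Bunne et al., 2022] is not shown.

```python
import jax
import jax.numpy as jnp


def _sqrtm_psd(S):
    """Square root of a symmetric positive semi-definite matrix via eigh."""
    evals, evecs = jnp.linalg.eigh(S)
    return (evecs * jnp.sqrt(jnp.maximum(evals, 0.0))) @ evecs.T


def gaussian_ot_potential(source, target):
    """Fits Gaussians to (n, d) sample arrays; returns (potential, A, b).

    The Gaussian OT map is T(x) = A x + b, the gradient of the returned
    convex quadratic potential. Assumes the source covariance is invertible.
    """
    m_s, m_t = source.mean(axis=0), target.mean(axis=0)
    S_s = jnp.cov(source, rowvar=False)
    S_t = jnp.cov(target, rowvar=False)
    S_s_half = _sqrtm_psd(S_s)
    S_s_inv_half = jnp.linalg.inv(S_s_half)
    A = S_s_inv_half @ _sqrtm_psd(S_s_half @ S_t @ S_s_half) @ S_s_inv_half
    b = m_t - A @ m_s

    def potential(x):
        # Convex because A is symmetric positive definite.
        return 0.5 * x @ A @ x + b @ x

    return potential, A, b
```

Applying `jax.grad` to the returned potential recovers the affine map A x + b, and A pushes the source covariance onto the target one (A S_s A = S_t).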
- Return type
None
Methods
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent]) – Creates a clone of this Module, with optionally updated arguments.
create_train_state(rng, optimizer, input) – Create initial TrainState.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with name name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_fn([dtype]) – Builds an initializer that returns real normally-distributed random arrays.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
make_rng(name) – Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
Attributes
act_fn(x)
variables – Returns the variables in this module.
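The activation `act_fn` must be convex and non-decreasing for the network output to stay convex, since it is applied to non-negative combinations of convex functions. A quick numerical check on a 1-D grid can catch an unsuitable choice. `looks_convex_nondecreasing` is a hypothetical helper and only a sanity check, not a proof: it inspects a sampled interval and nothing beyond it.

```python
import jax
import jax.numpy as jnp


def looks_convex_nondecreasing(fn, lo=-5.0, hi=5.0, n=1001, tol=1e-5):
    """Checks monotonicity and convexity of a scalar function on a uniform grid."""
    t = jnp.linspace(lo, hi, n)
    y = fn(t)
    # Non-decreasing: consecutive differences are (numerically) non-negative.
    nondecreasing = jnp.all(jnp.diff(y) >= -tol)
    # Convex on a uniform grid: second differences are (numerically) non-negative.
    convex = jnp.all(y[2:] - 2 * y[1:-1] + y[:-2] >= -tol)
    return bool(nondecreasing & convex)
```

`jax.nn.relu` and `jax.nn.softplus` pass this check, while `jnp.tanh` (concave for positive inputs) and `jnp.square` (decreasing for negative inputs) do not.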