ott.core.icnn.ICNN

class ott.core.icnn.ICNN(dim_hidden, init_std=0.1, init_fn=jax.nn.initializers.normal, act_fn=nn.relu, pos_weights=True, dim_data=2, gaussian_map=None, parent=<flax.linen.module._Sentinel object>, name=None)

Input convex neural network (ICNN) architecture with initialization.

Implementation of input convex neural networks as introduced in [Amos et al., 2017] with initialization schemes proposed by [Bunne et al., 2022].
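To illustrate what "input convex" means, here is a toy, pure-JAX sketch of the ICNN recursion from Amos et al.: each hidden state combines the previous hidden state through non-negative weights with an affine function of the input, followed by a convex, non-decreasing activation. This is not the OTT implementation. The parameter layout, the names `init_icnn_params` and `icnn_forward`, and the use of `softplus` to make the hidden-path weights non-negative (loosely analogous to `pos_weights=True`) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def init_icnn_params(key, dim_data, dim_hidden, init_std=0.1):
    """Random parameters for a toy ICNN: hidden sizes, then a scalar output."""
    sizes = list(dim_hidden) + [1]
    keys = jax.random.split(key, 2 * len(sizes))
    # Weights from the input x to every layer.
    w_x = [init_std * jax.random.normal(keys[2 * i], (dim_data, n))
           for i, n in enumerate(sizes)]
    # Weights on the hidden path z_{k-1} -> z_k (none for the first layer).
    w_z = [init_std * jax.random.normal(keys[2 * i + 1], (sizes[i - 1], n))
           for i, n in enumerate(sizes) if i > 0]
    b = [jnp.zeros(n) for n in sizes]
    return {"w_x": w_x, "w_z": w_z, "b": b}


def icnn_forward(params, x, act_fn=jax.nn.relu):
    """Evaluate the toy ICNN at a single input x of shape (dim_data,)."""
    w_x, w_z, b = params["w_x"], params["w_z"], params["b"]
    z = act_fn(x @ w_x[0] + b[0])
    for k in range(1, len(w_x)):
        # softplus makes the hidden-path weights non-negative, which keeps
        # z_k convex in x as long as act_fn is convex and non-decreasing.
        pre = z @ jax.nn.softplus(w_z[k - 1]) + x @ w_x[k] + b[k]
        # No activation on the final (scalar) layer.
        z = act_fn(pre) if k < len(w_x) - 1 else pre
    return z[0]


params = init_icnn_params(jax.random.PRNGKey(0), dim_data=2, dim_hidden=[16, 16])
# Batched potential: maps an array of shape (n, 2) to shape (n,).
f = jax.vmap(lambda x: icnn_forward(params, x))
```

Dropping the non-negativity (for example, using `w_z` directly) generally breaks convexity, which is why the constructor exposes `pos_weights` to either enforce it or rely on a regularizer.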

Parameters
  • dim_hidden (Sequence[int]) – sequence specifying the sizes of the hidden layers. The last layer has output dimension 1 by default, so the network returns a scalar.

  • init_std (float) – standard deviation used by the weight initialization method.

  • init_fn (Callable) – choice of initialization method for weight matrices (default: jax.nn.initializers.normal).

  • act_fn (Callable) – activation function used in the network architecture (must be convex and non-decreasing to preserve input convexity; default: nn.relu).

  • pos_weights (bool) – whether to enforce positivity of the weights (True) or to penalize negative weights with a regularizer instead (False).

  • dim_data (int) – data dimensionality (default: 2).

  • gaussian_map (Optional[Tuple[ndarray, ndarray]]) – source and target data samples used by the initialization scheme based on a Gaussian approximation of the source and target measures (if None, identity initialization is used).

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –

  • name (str) –
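The `gaussian_map` option refers to an initialization based on Gaussian approximations of the two measures. Under the squared Euclidean cost, the optimal transport map between N(m_s, Σ_s) and N(m_t, Σ_t) is affine, T(x) = m_t + A(x - m_s) with A = Σ_s^{-1/2} (Σ_s^{1/2} Σ_t Σ_s^{1/2})^{1/2} Σ_s^{-1/2}, and it is the gradient of the convex quadratic f(x) = ½ xᵀAx + (m_t - A m_s)ᵀx. Identity initialization corresponds to A = I and zero offset, i.e. f(x) = ½‖x‖². The pure-JAX sketch below fits these quantities from samples; the function names and the small `eps` regularizer are hypothetical, and how OTT maps A and b onto the ICNN's layer weights is not shown here.

```python
import jax
import jax.numpy as jnp


def _sqrtm_psd(m):
    # Square root of a symmetric PSD matrix via its eigendecomposition.
    w, v = jnp.linalg.eigh(m)
    return (v * jnp.sqrt(jnp.maximum(w, 0.0))) @ v.T


def _inv_sqrtm_pd(m):
    # Inverse square root of a symmetric positive-definite matrix.
    w, v = jnp.linalg.eigh(m)
    return (v / jnp.sqrt(w)) @ v.T


def gaussian_map_params(source, target, eps=1e-6):
    """Return (A, b) such that T(x) = A @ x + b is the OT map between
    Gaussian fits of `source` and `target` (rows are samples)."""
    d = source.shape[1]
    m_s, m_t = source.mean(axis=0), target.mean(axis=0)
    # Small ridge (hypothetical choice) keeps the covariances invertible.
    cov_s = jnp.cov(source, rowvar=False) + eps * jnp.eye(d)
    cov_t = jnp.cov(target, rowvar=False) + eps * jnp.eye(d)
    s_half = _sqrtm_psd(cov_s)
    s_inv_half = _inv_sqrtm_pd(cov_s)
    a = s_inv_half @ _sqrtm_psd(s_half @ cov_t @ s_half) @ s_inv_half
    a = 0.5 * (a + a.T)  # remove round-off asymmetry
    b = m_t - a @ m_s
    return a, b


def gaussian_potential(x, a, b):
    # Convex quadratic whose gradient is the affine map A x + b.
    return 0.5 * x @ a @ x + b @ x


key_s, key_t = jax.random.split(jax.random.PRNGKey(0))
source = jax.random.normal(key_s, (500, 2))
target = jax.random.normal(key_t, (500, 2)) @ jnp.array([[2.0, 0.0], [0.5, 1.0]]) + 3.0
a, b = gaussian_map_params(source, target)
grad_f = jax.grad(gaussian_potential)  # x -> A x + b
# Identity initialization is the special case A = I, b = 0.
a_id, b_id = jnp.eye(2), jnp.zeros(2)
```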

Return type

None

Methods

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent])

Creates a clone of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_fn([dtype])

Builds an initializer that returns real normally-distributed random arrays.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

make_rng(name)

Returns a new RNG key from a given RNG sequence for this Module.

param(name, init_fn, *init_args)

Declares and returns a parameter in this Module.

put_variable(col, name, value)

Sets the value of a Variable.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, method, mutable, ...])

Creates a summary of the Module represented as a table.

variable(col, name[, init_fn])

Declares and returns a variable in this Module.
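The `init_fn([dtype])` entry above is the default weight initializer, `jax.nn.initializers.normal`, which builds an initializer from a standard deviation. The snippet below shows how it combines with a value like `init_std`; that `ICNN` passes `init_std` as `stddev` in exactly this way is an assumption, not something this page states.

```python
import jax
import jax.numpy as jnp

init_std = 0.1
# normal(stddev) returns an initializer called as (key, shape[, dtype]).
weight_init = jax.nn.initializers.normal(stddev=init_std)
w = weight_init(jax.random.PRNGKey(0), (256, 128))
print(w.shape, float(jnp.std(w)))
```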

Attributes

act_fn(x)

dim_data

gaussian_map

init_std

name

parent

pos_weights

scope

variables

Returns the variables in this module.

dim_hidden