ott.solvers.nn.layers.PositiveDense
- class ott.solvers.nn.layers.PositiveDense(dim_hidden, rectifier_fn=<CompiledFunction of <function softplus>>, inv_rectifier_fn=<function PositiveDense.<lambda>>, use_bias=True, dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)
A linear transformation using a weight matrix with all entries positive.
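The idea of a linear map with an all-positive weight matrix can be sketched in plain JAX as follows. This is an illustration under assumptions, not the library's implementation: the helper name `positive_dense_sketch` is hypothetical, and it assumes positivity is obtained by passing an unconstrained kernel through the rectifier (softplus by default) before the matrix product.

```python
import jax
import jax.numpy as jnp


def positive_dense_sketch(x, raw_kernel, bias=None, rectifier_fn=jax.nn.softplus):
    """Hypothetical forward pass: rectify the raw kernel, then apply it."""
    kernel = rectifier_fn(raw_kernel)  # softplus makes every entry strictly positive
    y = x @ kernel
    if bias is not None:
        y = y + bias
    return y


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (4, 3))           # batch of 4 inputs with 3 features
raw_kernel = jax.random.normal(key_w, (3, 5))  # unconstrained values, 5 outputs
bias = jnp.zeros(5)

y = positive_dense_sketch(x, raw_kernel, bias)
effective_kernel = jax.nn.softplus(raw_kernel)  # the weights actually applied
```

Because every effective weight is positive, the output is nondecreasing in each input coordinate, which is the property that layers of this kind are typically used for.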
- Parameters
dim_hidden (int) – the number of output dimensions.
rectifier_fn (Callable[[Array], Array]) – choice of rectifier function (default: softplus function).
inv_rectifier_fn (Callable[[Array], Array]) – choice of inverse rectifier function (default: inverse softplus function).
dtype (Any) – the dtype of the computation (default: float32).
precision (Optional[Any]) – numerical precision of the computation; see jax.lax.Precision for details.
kernel_init (Callable[[Any, Tuple[int], Any], Any]) – initializer function for the weight matrix.
bias_init (Callable[[Any, Tuple[int], Any], Any]) – initializer function for the bias.
use_bias (bool) – whether to add a bias to the output (default: True).
parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) –
- Return type
None
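The rectifier_fn and inv_rectifier_fn parameters above form a pair. A minimal sketch of that relationship, assuming the usual closed form log(exp(y) - 1) for the inverse softplus on y > 0 (the helper name `inv_softplus` is hypothetical and not part of the library):

```python
import jax
import jax.numpy as jnp


def inv_softplus(y):
    """Inverse of softplus for y > 0, written as log(exp(y) - 1)."""
    return jnp.log(jnp.expm1(y))


# Positive values we would like the rectified kernel to take.
target = jnp.array([0.1, 0.5, 1.0, 3.0])

raw = inv_softplus(target)         # unconstrained values that map onto the target
recovered = jax.nn.softplus(raw)   # rectifying them recovers the target
```

One plausible use of such an inverse is choosing raw parameter values so that the rectified weights start at a prescribed positive value.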
Methods
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
bias_init(shape[, dtype]) – An initializer that returns a constant array full of zeros.
bind(variables, *args[, rngs, mutable]) – Creates an interactive Module instance by binding variables and RNGs.
clone(*[, parent]) – Creates a clone of this Module, with optionally updated arguments.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns true if a PRNGSequence with name name exists.
has_variable(col, name) – Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns true if the collection col is mutable.
kernel_init(shape[, dtype])
make_rng(name) – Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-value variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Sets the value of a Variable.
rectifier_fn(x) – Softplus activation function.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
tabulate(rngs, *args[, depth, ...]) – Creates a summary of the Module represented as a table.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
Attributes
variables – Returns the variables in this module.