ott.solvers.nn.layers.PositiveDense

class ott.solvers.nn.layers.PositiveDense(dim_hidden, rectifier_fn=<CompiledFunction of <function softplus>>, inv_rectifier_fn=<function PositiveDense.<lambda>>, use_bias=True, dtype=<class 'jax.numpy.float32'>, precision=None, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, parent=<flax.linen.module._Sentinel object>, name=None)

A linear transformation using a weight matrix whose entries are all positive.
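The forward pass can be sketched in plain JAX as follows. This is a minimal illustration, assuming (consistent with the defaults listed on this page) that positivity comes from passing an unconstrained kernel through the rectifier function (softplus) before the matrix product. `positive_dense_forward` is a hypothetical helper written for this example, not part of OTT.

```python
import jax
import jax.numpy as jnp


def positive_dense_forward(x, raw_kernel, bias=None, rectifier_fn=jax.nn.softplus):
    """Hypothetical helper: computes x @ rectifier_fn(raw_kernel) (+ bias)."""
    # Map the unconstrained kernel to strictly positive weights.
    kernel = rectifier_fn(raw_kernel)
    y = x @ kernel
    if bias is not None:
        y = y + bias
    return y


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (3, 5))           # batch of 3 inputs, 5 features each
raw_kernel = jax.random.normal(key_w, (5, 4))  # unconstrained, may contain negatives
bias = jnp.zeros((4,))

y = positive_dense_forward(x, raw_kernel, bias)
print(y.shape)  # (3, 4)
```

Because every effective weight is positive, each output is non-decreasing in each input coordinate; increasing all inputs by the same amount strictly increases every output.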

Parameters
  • dim_hidden (int) – the number of output dimensions.

  • rectifier_fn (Callable[[Array], Array]) – choice of rectifier function (default: softplus function).

  • inv_rectifier_fn (Callable[[Array], Array]) – choice of inverse rectifier function (default: inverse softplus function).

  • dtype (Any) – the dtype of the computation (default: float32).

  • precision (Optional[Any]) – numerical precision of the computation; see jax.lax.Precision for details.

  • kernel_init (Callable[[Any, Tuple[int], Any], Any]) – initializer function for the weight matrix.

  • bias_init (Callable[[Any, Tuple[int], Any], Any]) – initializer function for the bias.

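  • use_bias (bool) – whether to add a bias term to the output (default: True).

  • parent (Optional[Union[Type[flax.linen.module.Module], Type[flax.core.scope.Scope], Type[flax.linen.module._Sentinel]]]) – the parent module or scope; managed by Flax and normally not set by hand.

  • name (Optional[str]) – the name of this module instance.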
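The rectifier_fn and inv_rectifier_fn parameters above are a function and its inverse. A minimal sketch of the default pair, assuming the inverse softplus is written as log(exp(y) - 1) = log(expm1(y)); `inv_softplus` is a hypothetical name used only for illustration:

```python
import jax
import jax.numpy as jnp


def inv_softplus(y):
    # Hypothetical helper: inverse of softplus(x) = log(1 + exp(x)), valid for y > 0.
    return jnp.log(jnp.expm1(y))


x = jnp.array([-2.0, 0.0, 0.5, 3.0])
y = jax.nn.softplus(x)  # strictly positive outputs
x_back = inv_softplus(y)
print(jnp.allclose(x, x_back, atol=1e-5))  # True
```

One use of such an inverse is to choose raw parameter values that map to prescribed positive weights once the rectifier is applied.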

Return type

None

Methods

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

bias_init(shape[, dtype])

An initializer that returns a constant array full of zeros.

bind(variables, *args[, rngs, mutable])

Creates an interactive Module instance by binding variables and RNGs.

clone(*[, parent])

Creates a clone of this Module, with optionally updated arguments.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns True if a PRNGSequence named name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

inv_rectifier_fn()

Inverse rectifier function (default: inverse softplus).

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns True if the collection col is mutable.

kernel_init(shape[, dtype])

Initializer function for the weight matrix. Return type: Any.

make_rng(name)

Returns a new RNG key from a given RNG sequence for this Module.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Adds a zero-valued variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Sets the value of a Variable.

rectifier_fn()

Softplus activation function.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

tabulate(rngs, *args[, depth, ...])

Creates a summary of the Module represented as a table.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

Attributes

name

parent

precision

scope

use_bias

variables

Returns the variables in this module.

dim_hidden