ott.neural.networks.layers.posdef.PositiveDense

class ott.neural.networks.layers.posdef.PositiveDense(in_features, out_features, *, rectifier_fn=<PjitFunction of <function softplus>>, use_softmax=False, use_sinkhorn=False, use_bias=True, kernel_init=<function variance_scaling.<locals>.init>, bias_init=<function zeros>, rngs)

A linear transformation with non-negative weights.

The weights can be made non-negative in one of three modes:

  • Element-wise rectifier (default): applies rectifier_fn (e.g., softplus, relu) to each weight independently.

  • Softmax (use_softmax=True): applies a column-wise softmax so that each column sums to 1, producing column-stochastic weight matrices.

  • Sinkhorn (use_sinkhorn=True): applies Sinkhorn normalization in log-space, producing approximately doubly-stochastic weight matrices.
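In symbols, write V for the unconstrained kernel and W for the effective non-negative weights. This is a sketch under stated assumptions, not a transcription of the implementation: V is assumed to have shape (in_features, out_features), so that "columns" index outputs, and the Sinkhorn loop is written rows-then-columns for an unspecified finite number of iterations. Subtracting log-sum-exp terms instead of dividing by raw sums keeps the updates numerically stable. For a square kernel with positive entries, the alternating updates converge to a doubly-stochastic matrix (Sinkhorn's theorem), which is why a finite number of iterations yields only an approximation.

```latex
\begin{aligned}
&\text{rectifier:} && W_{ij} = \sigma(V_{ij}), \quad \text{e.g. } \sigma(v) = \log(1 + e^{v}) \text{ (softplus) or } \max(v, 0) \text{ (relu)} \\
&\text{softmax:} && W_{ij} = \frac{\exp(V_{ij})}{\sum_{k} \exp(V_{kj})}, \quad \text{so } \textstyle\sum_{i} W_{ij} = 1 \\
&\text{Sinkhorn:} && L \leftarrow V, \text{ then repeat: }
  L_{ij} \leftarrow L_{ij} - \log \textstyle\sum_{k} \exp(L_{ik}), \;\;
  L_{ij} \leftarrow L_{ij} - \log \textstyle\sum_{k} \exp(L_{kj}); \quad W = \exp(L)
\end{aligned}
```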

Parameters:
  • in_features (int) – Input dimension.

  • out_features (int) – Output dimension.

  • rectifier_fn (Optional[Callable[[Array], Array]]) – Function to enforce non-negativity. Ignored when use_softmax or use_sinkhorn is True.

  • use_softmax (bool) – If True, use column-wise softmax normalization.

  • use_sinkhorn (bool) – If True, use Sinkhorn normalization.

  • use_bias (bool) – Whether to add a bias term.

  • kernel_init (Union[Initializer, Callable[..., Any]]) – Initializer for the kernel.

  • bias_init (Union[Initializer, Callable[..., Any]]) – Initializer for the bias.

  • rngs (Rngs) – Random number generators.

  • args (Any)

  • kwargs (Any)
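A minimal functional sketch of how these parameters interact, written in plain JAX rather than with the library class. Everything here is an assumption for illustration: the helper names init_params and positive_dense_sketch are hypothetical, the kernel is assumed to have shape (in_features, out_features), jax.nn.initializers.lecun_normal() stands in for the documented variance-scaling default, num_iters is a hypothetical Sinkhorn iteration count, and rejecting use_softmax=True together with use_sinkhorn=True is this sketch's own choice, not documented behavior.

```python
import jax
import jax.numpy as jnp


def init_params(key, in_features, out_features, use_bias=True):
    """Hypothetical initializer: raw (unconstrained) kernel plus optional bias."""
    k_key, b_key = jax.random.split(key)
    # Assumption: lecun_normal stands in for the default variance-scaling kernel_init.
    kernel = jax.nn.initializers.lecun_normal()(k_key, (in_features, out_features))
    params = {"kernel": kernel}
    if use_bias:
        # Matches the documented default bias_init (zeros).
        params["bias"] = jax.nn.initializers.zeros(b_key, (out_features,))
    return params


def positive_dense_sketch(params, x, *, rectifier_fn=jax.nn.softplus,
                          use_softmax=False, use_sinkhorn=False, num_iters=50):
    """Hypothetical forward pass: constrain the raw kernel, then x @ W (+ b)."""
    if use_softmax and use_sinkhorn:
        raise ValueError("Set at most one of use_softmax and use_sinkhorn.")
    raw = params["kernel"]
    if use_softmax:
        # Softmax over the input axis: each output column sums to 1.
        kernel = jax.nn.softmax(raw, axis=0)
    elif use_sinkhorn:
        def step(_, log_k):
            # Normalize rows, then columns, in log-space for numerical stability.
            log_k = log_k - jax.nn.logsumexp(log_k, axis=1, keepdims=True)
            return log_k - jax.nn.logsumexp(log_k, axis=0, keepdims=True)
        kernel = jnp.exp(jax.lax.fori_loop(0, num_iters, step, raw))
    else:
        # rectifier_fn is consulted only in the default element-wise mode.
        kernel = rectifier_fn(raw)
    y = x @ kernel
    if "bias" in params:  # present only when use_bias=True
        y = y + params["bias"]
    return y


key = jax.random.PRNGKey(0)
params = init_params(key, in_features=3, out_features=2)
x = jnp.ones((5, 3))
y = positive_dense_sketch(params, x)  # softplus kernel, zero bias
y_soft = positive_dense_sketch(params, x, use_softmax=True)
print(y.shape)  # (5, 2)

# With a square kernel, no bias and x = I, the output is the kernel itself,
# so its row and column sums can be inspected directly.
sq = init_params(key, in_features=4, out_features=4, use_bias=False)
w_sink = positive_dense_sketch(sq, jnp.eye(4), use_sinkhorn=True)
print(w_sink.sum(axis=0), w_sink.sum(axis=1))  # both close to 1
```

Keeping the raw kernel unconstrained and applying the transform inside the forward pass lets a gradient-based optimizer update the parameters freely while the effective weights stay non-negative.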

Return type:

Any

Methods

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use flax.nnx.iter_children() instead.

iter_modules()

Warning: this method is deprecated; use flax.nnx.iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.