ott.neural.layers.PosDefPotentials.bias_init


PosDefPotentials.bias_init(key, shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of zeros.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.zeros(jax.random.PRNGKey(42), (2, 3), jnp.float32)
Array([[0., 0., 0.],
       [0., 0., 0.]], dtype=float32)
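Parameters:
    key – PRNG key; ignored.
    shape – shape of the returned array.
    dtype – dtype of the returned array (default: jax.numpy.float64).
Return type:
    Array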
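The sketch below illustrates the claim that the key argument is ignored. It only assumes that bias_init behaves like jax.nn.initializers.zeros, as the example above shows. The variable names (init, a, b, bias, normal, n0, n1) are illustrative and not part of the OTT API. For contrast, it also calls jax.nn.initializers.normal, whose output does depend on the key.

```python
import jax
import jax.numpy as jnp

# Standard JAX initializer signature: init(key, shape, dtype).
init = jax.nn.initializers.zeros

# Different PRNG keys give identical outputs, because zeros ignores the key.
a = init(jax.random.PRNGKey(0), (3,), jnp.float32)
b = init(jax.random.PRNGKey(1), (3,), jnp.float32)

# Any shape is accepted, e.g. a bias vector for a hypothetical layer
# with 4 output features.
bias = init(jax.random.PRNGKey(0), (4,), jnp.float32)

# Contrast: a random initializer does use the key, so its outputs differ.
normal = jax.nn.initializers.normal(stddev=1.0)
n0 = normal(jax.random.PRNGKey(0), (3,), jnp.float32)
n1 = normal(jax.random.PRNGKey(1), (3,), jnp.float32)

print(a, b, bias.shape, bias.dtype)
```

A zero bias is a common default because it adds no offset at initialization. The key is kept in the signature only so that zeros can be used anywhere a random initializer is expected.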