ott.neural.networks.layers.posdef.PosDefPotentials.kernel_diag_init


PosDefPotentials.kernel_diag_init(key, shape, dtype=<class 'jax.numpy.float64'>)

An initializer that returns a constant array full of ones.

The key argument is ignored.

>>> import jax, jax.numpy as jnp
>>> jax.nn.initializers.ones(jax.random.key(42), (3, 2), jnp.float32)
Array([[1., 1.],
       [1., 1.],
       [1., 1.]], dtype=float32)
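The claim that the key argument is ignored can be checked directly. The sketch below calls jax.nn.initializers.ones, the function used in the doctest above, with two different PRNG keys and compares the results. Like the doctest, it assumes a JAX version that provides jax.random.key.

```python
import jax
import jax.numpy as jnp

# Same initializer as in the doctest above.
init = jax.nn.initializers.ones

# Two different PRNG keys. Because the key is ignored,
# both calls should return identical all-ones arrays.
a = init(jax.random.key(0), (2, 3), jnp.float32)
b = init(jax.random.key(1), (2, 3), jnp.float32)

print(a.shape, a.dtype)              # (2, 3) float32
print(bool(jnp.array_equal(a, b)))   # True
print(bool(jnp.all(a == 1.0)))       # True
```

Since the output does not depend on the key, the initializer is deterministic: every call with the same shape and dtype produces the same array.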
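
Parameters:
    key – ignored.
    shape – shape of the returned array.
    dtype – dtype of the returned array (default: jax.numpy.float64, which JAX canonicalizes to float32 unless 64-bit mode is enabled).

Return type:
    Array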