scaling (Array) – jnp.ndarray of num_a or num_b positive values.
marginal (Array) – target marginal.
iteration (Optional[int]) – used to compute epsilon from the schedule, if provided.
axis (int) – axis along which the update should be carried out.
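To show how these parameters typically interact, here is a minimal sketch of one Sinkhorn scaling step in JAX. It is not the library's implementation. Several pieces are assumptions: a dense `cost` matrix is passed explicitly (it is not one of the parameters above), `epsilon_at` is a hypothetical exponential-decay schedule standing in for the real one, and the axis convention (`axis=0` maps a length-num_a scaling to a new length-num_b scaling) is chosen for illustration. Both function names are hypothetical.

```python
import jax.numpy as jnp


def epsilon_at(iteration, target=0.5, init=1.0, decay=0.9):
    # Hypothetical schedule: eps = max(target, init * decay**iteration).
    # With no iteration given, fall back to the target value.
    if iteration is None:
        return target
    return max(target, init * decay ** iteration)


def sinkhorn_scaling_update(scaling, marginal, cost, iteration=None, axis=0):
    """One Sinkhorn update against the dense Gibbs kernel K = exp(-cost / eps).

    Assumed convention:
      axis=0: `scaling` is u (length num_a); returns v = marginal / (K^T u).
      axis=1: `scaling` is v (length num_b); returns u = marginal / (K v).
    """
    eps = epsilon_at(iteration)
    kernel = jnp.exp(-cost / eps)
    applied = kernel.T @ scaling if axis == 0 else kernel @ scaling
    # Avoid dividing by zero where the kernel application underflows.
    return marginal / jnp.where(applied > 0, applied, 1.0)
```

Alternating the two axes gives the usual Sinkhorn loop: update the column scaling with `axis=0` against the num_b marginal, then the row scaling with `axis=1` against the num_a marginal. After enough iterations, `u[:, None] * K * v[None, :]` has both target marginals.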