scaling (Array) – jnp.ndarray of num_a or num_b positive values.
marginal (Array) – target marginal.
iteration (Optional[int]) – used to compute epsilon from the schedule, if provided.
axis (int) – axis along which the update should be carried out.
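To show how these parameters might fit together, here is a minimal sketch of a kernel-mode Sinkhorn scaling update. It is not this library's implementation. It assumes a dense `(num_a, num_b)` cost matrix, a Gibbs kernel `exp(-cost / eps)`, and a hypothetical geometric schedule `epsilon_at`. Both `epsilon_at` and the `cost` argument are illustrative inventions. NumPy is used here, and the same calls exist in `jax.numpy`.

```python
import numpy as np


def epsilon_at(iteration, target=0.1, init=1.0, decay=0.9):
    """Hypothetical geometric epsilon schedule (illustrative only).

    Returns ``target`` when no iteration is given; otherwise decays from
    ``init`` toward ``target`` by a factor ``decay`` per iteration.
    """
    if iteration is None:
        return target
    return max(target, init * decay ** iteration)


def update_scaling(cost, scaling, marginal, iteration=None, axis=0):
    """One kernel-mode Sinkhorn scaling update (sketch, not the library API).

    cost: (num_a, num_b) cost matrix (an assumption of this sketch).
    axis=0: ``scaling`` has num_a entries; the result has num_b entries.
    axis=1: ``scaling`` has num_b entries; the result has num_a entries.
    """
    eps = epsilon_at(iteration)
    kernel = np.exp(-cost / eps)  # Gibbs kernel K = exp(-C / eps)
    # Contract the kernel with ``scaling`` along the requested axis.
    applied = kernel.T @ scaling if axis == 0 else kernel @ scaling
    # Divide the target marginal by the kernel product, guarding against zeros.
    return marginal / np.where(applied > 0, applied, 1.0)


# Usage: one round of alternating updates on a 2 x 3 problem.
cost = np.array([[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]])
a = np.array([0.4, 0.6])        # num_a = 2
b = np.array([0.2, 0.3, 0.5])   # num_b = 3
u0 = np.ones(2)
v = update_scaling(cost, u0, b, iteration=0, axis=0)   # num_b entries
u1 = update_scaling(cost, v, a, iteration=0, axis=1)   # num_a entries
```

After each update, the marginal along the updated side matches its target exactly: the `axis=0` update fixes the column sums of `diag(u0) K diag(v)` to `b`, and the `axis=1` update fixes the row sums of `diag(u1) K diag(v)` to `a`. Repeating the pair is what drives both marginals toward their targets.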