- Geometry.apply_lse_kernel(f, g, eps, vec=None, axis=0)
Apply kernel_matrix in the log domain.
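This function applies the ground geometry’s kernel in the log domain, using a stabilized formulation. At a high level, it computes one of:

$$\text{output} = \varepsilon \log\left(K\,(e^{g/\varepsilon} \odot \text{vec})\right) \tag{1}$$
$$\text{output} = \varepsilon \log\left(K^\top (e^{f/\varepsilon} \odot \text{vec})\right) \tag{2}$$

where ε is eps, ⊙ denotes elementwise multiplication, vec acts as a vector of ones when it is None, and K = exp(-cost_matrix / eps) is used only implicitly.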
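For reference, a direct (unstabilized) transcription of (1) and (2) is sketched below in JAX. The helper name `apply_kernel_naive` is hypothetical and not part of the Geometry API; it assumes the cost matrix is passed in as a dense array `cost` and builds K explicitly, which is exactly what the stabilized formulation avoids.

```python
import jax.numpy as jnp


def apply_kernel_naive(f, g, cost, eps, vec=None, axis=0):
    """Hypothetical, unstabilized evaluation of (1) (axis=1) or (2) (axis=0)."""
    kernel = jnp.exp(-cost / eps)  # K, materialized explicitly
    if axis == 1:
        # (1): eps * log(K (exp(g / eps) * vec)), summing over the columns of K.
        scaled = jnp.exp(g / eps) if vec is None else jnp.exp(g / eps) * vec
        return eps * jnp.log(kernel @ scaled)
    # (2): eps * log(K^T (exp(f / eps) * vec)), summing over the rows of K.
    scaled = jnp.exp(f / eps) if vec is None else jnp.exp(f / eps) * vec
    return eps * jnp.log(kernel.T @ scaled)
```

With f = 0, g = 0, an all-ones 2 x 3 cost matrix and eps = 1e-3, every entry of K is exp(-1000), which underflows to 0, so this helper returns -inf for every row.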
To carry this out in a stabilized way, we take advantage of the fact that the entries of the matrix f[:, None] + g[None, :] - C (with C the cost_matrix) are all negative, so their exponentials never overflow. We therefore add f in (1), or g in (2), inside the log-sum-exp and subtract it again afterwards.
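A minimal JAX sketch of this add-then-subtract trick follows. It assumes a dense cost matrix passed in as `cost`; the name `apply_lse_kernel_sketch` is hypothetical, and the code mirrors the description above rather than the library's actual implementation.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def apply_lse_kernel_sketch(f, g, cost, eps, axis=0):
    """Hypothetical stabilized evaluation of (1) (axis=1) or (2) (axis=0)."""
    # Centered exponents (f_i + g_j - C_ij) / eps; K itself is never formed.
    centered = (f[:, None] + g[None, :] - cost) / eps
    # logsumexp shifts by the maximum before exponentiating, so exp never overflows.
    lse = eps * logsumexp(centered, axis=axis)
    # Subtract the potential that was added only for stability:
    # f for (1) (axis=1), g for (2) (axis=0).
    return lse - (f if axis == 1 else g)
```

On the all-ones 2 x 3 cost with f = 0, g = 0 and eps = 1e-3, this returns approximately -1 + 1e-3 * log(3) for each row, whereas evaluating K explicitly underflows to 0 and yields -inf.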
- Parameters:
  - f (Array) – jnp.ndarray [num_a,], potential of size num_rows of cost_matrix.
  - g (Array) – jnp.ndarray [num_b,], potential of size num_cols of cost_matrix.
  - eps (float) – regularization strength.
  - vec (Optional[Array]) – jnp.ndarray [num_a or num_b,] (num_b for (1), num_a for (2)). When not None, the log-kernel computation additionally multiplies exp(g / eps), or exp(f / eps) in (2), elementwise by vec. This is carried out by passing vec as weights to the log-sum-exp function; since vec may contain negative entries, whose logarithm is undefined, signs must be handled separately.
  - axis (int) – sum over axis 0 when computing (2), or over axis 1 when computing (1).
- Returns:
A jnp.ndarray corresponding to output (1) or (2) above, depending on axis.
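The sign handling for vec can be sketched with the `b` and `return_sign` arguments of `jax.scipy.special.logsumexp`. This is a hypothetical helper, `apply_lse_kernel_vec_sketch`, assuming a dense `cost` array; returning the sign as a second value is a choice made for this sketch, not a statement about the documented return type.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def apply_lse_kernel_vec_sketch(f, g, cost, eps, vec, axis=0):
    """Hypothetical signed variant: returns (out, sign) such that
    sign * exp(out / eps) equals K (exp(g / eps) * vec) for axis=1,
    or K^T (exp(f / eps) * vec) for axis=0."""
    centered = (f[:, None] + g[None, :] - cost) / eps
    # vec weights the summed axis: columns in (1), rows in (2).
    weights = vec[None, :] if axis == 1 else vec[:, None]
    weights = jnp.broadcast_to(weights, centered.shape)
    # Weights may be negative, so logsumexp returns log|sum| and the sign separately.
    lse, sign = logsumexp(centered, axis=axis, b=weights, return_sign=True)
    # Remove the stabilizing potential, exactly as in the unweighted case.
    return eps * lse - (f if axis == 1 else g), sign
```

Keeping the result as a (log-magnitude, sign) pair lets a weighted sum that turns out negative be represented without ever taking the logarithm of a negative number.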