ott.geometry.geodesic.Geodesic.apply_lse_kernel
- Geodesic.apply_lse_kernel(f, g, eps, vec=None, axis=0)
Apply kernel_matrix in log domain.
This function applies the ground geometry’s kernel in log domain, using a stabilized formulation. At a high level, this iteration performs either:
output = eps * log (K (exp(g / eps) * vec)) (1)
output = eps * log (K’(exp(f / eps) * vec)) (2)
Here K is implicitly exp(-cost_matrix / eps), and K’ denotes its transpose.
To carry this out in a stabilized way, we take advantage of the fact that the entries of the matrix
f[:,*] + g[*,:] - C
are all negative, and therefore their exponential never overflows, to add (and subtract after) f and g in iterations 1 & 2 respectively.
- Parameters:
f (Array) – jnp.ndarray [num_a,], potential of size num_rows of cost_matrix
g (Array) – jnp.ndarray [num_b,], potential of size num_cols of cost_matrix
eps (float) – float, regularization strength
vec (Optional[Array]) – jnp.ndarray [num_a or num_b,]; when not None, this has the effect of doing log-kernel computations with an additional elementwise multiplication of exp(g / eps) in (1), or exp(f / eps) in (2), by a vector. This is carried out by adding weights to the log-sum-exp function, and needs to handle signs separately.
axis (int) – summing over axis 0 when doing (2), or over axis 1 when doing (1)
- Return type:
- Returns:
A jnp.ndarray corresponding to output above, depending on axis.
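The stabilized formulation described above can be illustrated with a minimal sketch on a small dense cost matrix. This is not the library's implementation: the function name `lse_kernel_sketch`, the explicit `cost` argument, and the choice to return the sign as a second value are assumptions made for illustration only. It relies on `jax.scipy.special.logsumexp`, whose `b` and `return_sign` arguments carry the weights and their sign.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def lse_kernel_sketch(cost, f, g, eps, vec=None, axis=0):
    """Hypothetical dense stand-in for the log-domain kernel application."""
    # Centered matrix f[:, None] + g[None, :] - C, rescaled by eps.
    centered = (f[:, None] + g[None, :] - cost) / eps
    if vec is None:
        lse = logsumexp(centered, axis=axis)
        sign = jnp.ones_like(lse)
    else:
        # vec weights the rows when summing over axis 0, the columns otherwise.
        weights = vec[:, None] if axis == 0 else vec[None, :]
        # The sign of the weighted sum is tracked separately from its log.
        lse, sign = logsumexp(centered, b=weights, axis=axis, return_sign=True)
    # Subtract the potential that was added only for stabilization:
    # g for case (2) (axis=0), f for case (1) (axis=1).
    remove = g if axis == 0 else f
    return eps * lse - remove, sign


# Toy data (assumed values) with num_a = 2 rows and num_b = 3 columns.
cost = jnp.array([[0.0, 1.0, 2.0], [1.0, 0.0, 1.0]])
f = jnp.array([-0.3, -0.2])
g = jnp.array([-0.1, -0.4, -0.2])
eps = 0.5
K = jnp.exp(-cost / eps)

# Case (1): sum over axis 1, output has size num_a.
out1, _ = lse_kernel_sketch(cost, f, g, eps, axis=1)
naive1 = eps * jnp.log(K @ jnp.exp(g / eps))

# Case (2) with weights: sum over axis 0, output has size num_b.
vec = jnp.array([1.0, 2.0])
out2, sign2 = lse_kernel_sketch(cost, f, g, eps, vec=vec, axis=0)
naive2 = eps * jnp.log(K.T @ (jnp.exp(f / eps) * vec))
```

On this toy problem the stabilized and naive results agree up to floating-point error; the difference only matters when eps is small enough that exp(-cost / eps) or exp(g / eps) would underflow or overflow in the naive form.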