ott.geometry.low_rank.LRCGeometry.apply_lse_kernel


LRCGeometry.apply_lse_kernel(f, g, eps, vec=None, axis=0)

Apply kernel_matrix in the log domain.

This function applies the ground geometry’s kernel in the log domain, using a stabilized formulation. At a high level, it computes one of:

  • (1) output = eps * log(K (exp(g / eps) * vec))

  • (2) output = eps * log(K^T (exp(f / eps) * vec)), where K^T is the transpose of K

K is implicitly exp(-cost_matrix/eps).
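As a point of reference, formulas (1) and (2) can be evaluated directly by materializing K. The sketch below does this in JAX. `naive_apply_lse_kernel` is a hypothetical helper written for illustration (it is not part of OTT); it assumes an explicit dense `cost` array of shape [num_a, num_b] and that the weighted sums are positive (for example, vec is None or nonnegative).

```python
import jax.numpy as jnp


def naive_apply_lse_kernel(f, g, cost, eps, vec=None, axis=0):
    """Direct, unstabilized evaluation of (1) (axis=1) or (2) (axis=0).

    Hypothetical helper for illustration; not part of OTT. Assumes the
    weighted sums are positive (e.g. vec is None or nonnegative).
    """
    kernel = jnp.exp(-cost / eps)  # K = exp(-cost_matrix / eps), shape [num_a, num_b]
    if axis == 1:
        # (1): eps * log(K @ (exp(g / eps) * vec)), output shape [num_a,]
        w = jnp.exp(g / eps)
        w = w if vec is None else w * vec
        return eps * jnp.log(kernel @ w)
    # (2): eps * log(K.T @ (exp(f / eps) * vec)), output shape [num_b,]
    w = jnp.exp(f / eps)
    w = w if vec is None else w * vec
    return eps * jnp.log(kernel.T @ w)
```

This direct form is only reliable for moderate eps: for small eps, the entries of exp(-cost / eps) underflow to 0 and the logarithm returns -inf, which is the failure mode the stabilized formulation avoids.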

To carry this out in a stabilized way, we take advantage of the fact that the entries of the matrix f[:, None] + g[None, :] - C are non-positive whenever f and g are dual feasible (f_i + g_j <= C_ij), so their exponentials never overflow. We therefore add f inside the exponential (and subtract it afterwards) when computing (1), and do the same with g when computing (2).
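A minimal sketch of the stabilized computation, assuming a dense `cost` array and no vec weighting. `stabilized_apply_lse_kernel` is a hypothetical name and this is not OTT's actual implementation; it relies on `jax.scipy.special.logsumexp`, which additionally subtracts the maximum before exponentiating.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def stabilized_apply_lse_kernel(f, g, cost, eps, axis=0):
    """Stabilized evaluation of (1) (axis=1) or (2) (axis=0), without vec.

    Hypothetical helper sketching the trick described above; not OTT's code.
    """
    # Centered matrix f[:, None] + g[None, :] - C; its entries are <= 0 when
    # f and g are dual feasible, so exponentiating them cannot overflow.
    centered = f[:, None] + g[None, :] - cost
    lse = eps * logsumexp(centered / eps, axis=axis)
    # Remove the potential that was added but not summed over:
    # f for (1) (axis=1), g for (2) (axis=0).
    return lse - (f if axis == 1 else g)
```

For instance, with cost = [[1, 2], [3, 1]], f = g = 0 and eps = 1e-3, this returns finite values close to -1 along axis=1, whereas eps * log(exp(-cost / eps).sum(axis=1)) underflows to -inf.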

Parameters:
  • f (Array) – jnp.ndarray of shape [num_a,], potential whose size equals the number of rows (num_rows) of cost_matrix

  • g (Array) – jnp.ndarray of shape [num_b,], potential whose size equals the number of columns (num_cols) of cost_matrix

  • eps (float) – regularization strength

  • vec (Optional[Array]) – jnp.ndarray of shape [num_b,] when computing (1) or [num_a,] when computing (2). When not None, the log-kernel computation includes an additional elementwise multiplication of exp(g / eps) (in (1)) or exp(f / eps) (in (2)) by vec; when None, no weighting is applied. This is carried out by passing vec as weights to the log-sum-exp function; since vec may contain negative entries, signs are handled separately.

  • axis (int) – axis to sum over: 0 computes (2) (output of size num_b), 1 computes (1) (output of size num_a)
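The signed weighting described for vec can be sketched with the `b` and `return_sign` arguments of `jax.scipy.special.logsumexp`. The helper name `stabilized_apply_lse_kernel_vec`, its dense `cost` argument, and its (log-magnitude, sign) return pair are assumptions of this sketch, not the documented return value of the method above.

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def stabilized_apply_lse_kernel_vec(f, g, cost, eps, vec, axis=0):
    """Signed, vec-weighted variant of (1) (axis=1) or (2) (axis=0).

    Hypothetical helper for illustration; not OTT's implementation.
    Returns (value, sign) such that sign * exp(value / eps) equals
    K @ (exp(g / eps) * vec) for axis=1, or K.T @ (exp(f / eps) * vec) for axis=0.
    """
    centered = f[:, None] + g[None, :] - cost
    # vec weights the axis being summed over: shape [num_a,] for axis=0,
    # [num_b,] for axis=1. Broadcast it explicitly to the matrix shape.
    weights = vec[:, None] if axis == 0 else vec[None, :]
    weights = jnp.broadcast_to(weights, centered.shape)
    # return_sign=True yields log|weighted sum| together with its sign, so
    # negative entries in vec never require the log of a negative number.
    lse, sign = logsumexp(centered / eps, axis=axis, b=weights, return_sign=True)
    return eps * lse - (f if axis == 1 else g), sign
```

Recovering sign * exp(value / eps) and comparing it with the dense product K.T @ (exp(f / eps) * vec) (or K @ (exp(g / eps) * vec)) is a convenient sanity check at moderate eps.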

Return type:

Array

Returns:

A jnp.ndarray holding output (1) or (2) above, as selected by axis.