ott.geometry.costs.ElasticSTVS


class ott.geometry.costs.ElasticSTVS(scaling_reg=1.0, matrix=None, orthogonal=False)

Elastic cost whose regularizer has the soft thresholding operator with vanishing shrinkage (STVS) [Schreck et al., 2016] as its proximal operator.

\[\frac{1}{2} \|\cdot\|_2^2 + \text{scaling\_reg}^2\,\mathbf{1}_d^T\left(\sigma(\cdot) - \frac{1}{2} \exp\left(-2\sigma(\cdot)\right) + \frac{1}{2}\right)\]

where \(\sigma(\cdot) := \operatorname{asinh}\left(\frac{|\cdot|}{2\,\text{scaling\_reg}}\right)\) is applied elementwise.
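The cost can be evaluated directly from the formula. The NumPy sketch below transcribes it for a single displacement vector z. It is not OTT's implementation, and `stvs_reg` and `stvs_cost` are hypothetical names. σ is evaluated on |z| so that the penalty is even in each coordinate.

```python
import numpy as np


def stvs_reg(z: np.ndarray, scaling_reg: float = 1.0) -> float:
    """STVS penalty: scaling_reg**2 * sum(sigma - exp(-2 sigma) / 2 + 1 / 2)."""
    # sigma is evaluated on |z| so the penalty is symmetric in each coordinate
    sigma = np.arcsinh(np.abs(z) / (2.0 * scaling_reg))
    per_coord = sigma - 0.5 * np.exp(-2.0 * sigma) + 0.5
    # the 1_d^T (...) in the formula is a sum over the d coordinates
    return float(scaling_reg**2 * np.sum(per_coord))


def stvs_cost(z: np.ndarray, scaling_reg: float = 1.0) -> float:
    """Translation-invariant cost h(z) = 0.5 * ||z||^2 + STVS penalty."""
    return float(0.5 * np.dot(z, z) + stvs_reg(z, scaling_reg))


z = np.array([0.0, 1.5, -1.5])
print(stvs_reg(np.zeros(3)))         # 0.0: the "+ 1/2" makes the penalty vanish at 0
print(stvs_reg(z) == stvs_reg(-z))   # True
```

The constant \(\frac{1}{2}\) only shifts the penalty so that it is zero at the origin and nonnegative elsewhere.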

Parameters:
  • scaling_reg (float) – Strength of the regularization.

  • matrix (Optional[Array]) – \(p \times d\) projection matrix with orthogonal rows.

  • orthogonal (bool) – Whether to regularize in the orthogonal complement of the row span of matrix, which promotes displacements within that span.
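The sketch below shows one reading of the matrix and orthogonal parameters. It assumes the rows of matrix are orthonormal, so that \(\text{matrix}^T \text{matrix}\) projects onto their span.

- With orthogonal=False, the penalty is applied to the p coordinates matrix @ z.
- With orthogonal=True, the penalty is applied to the component of z outside the span, which leaves displacements inside the span unpenalized.

The helper `project_for_reg` is hypothetical. This is an assumption about how the two parameters combine, not OTT's code.

```python
import numpy as np


def project_for_reg(z, matrix=None, orthogonal=False):
    """Hypothetical helper: the vector the regularizer would be applied to."""
    if matrix is None:
        return z  # no projection: regularize z itself
    if orthogonal:
        # component of z outside the row span of `matrix` (rows assumed orthonormal)
        return z - matrix.T @ (matrix @ z)
    # coordinates of z in the p-dimensional subspace spanned by the rows
    return matrix @ z


A = np.array([[1.0, 0.0, 0.0]])          # p = 1, d = 3, a single unit-norm row
z_in_span = np.array([2.0, 0.0, 0.0])
print(project_for_reg(z_in_span, A, orthogonal=True))  # [0. 0. 0.]: not penalized
print(project_for_reg(z_in_span, A))                   # [2.]
```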

Methods

all_pairs(x, y)

Compute matrix of all pairwise costs, including the norms.

all_pairs_pairwise(x, y)

Compute matrix of all pairwise costs, excluding the norms.

barycenter(weights, xs)

Output barycenter of vectors.

h(z)

Translation-invariant (TI) function acting on the difference \(x-y\) to output the cost.

h_legendre(z)

Legendre transform of h() when it is convex.

h_transform(f, **kwargs)

Compute the h-transform of a concave function.

pairwise(x, y)

Compute cost as evaluation of h() on \(x-y\).

prox_legendre_reg(z[, tau])

Proximal operator of the Legendre transform of reg().

prox_reg(z[, tau])

Proximal operator of reg().

reg(z)

Regularization function.

twist_operator(vec, dual_vec, variable)

Twist inverse operator of the cost function.
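Because the cost is translation invariant, each entry of the pairwise cost matrix is h evaluated on a difference: \(C_{ij} = h(x_i - y_j)\). Below is a broadcasting NumPy sketch with hypothetical helper names. It covers only the \(h(x_i - y_j)\) part and does not model the norm terms mentioned for all_pairs.

```python
import numpy as np


def stvs_cost(z, scaling_reg=1.0):
    """h(z) = 0.5 ||z||^2 + scaling_reg^2 * sum(sigma - exp(-2 sigma)/2 + 1/2) over the last axis."""
    sigma = np.arcsinh(np.abs(z) / (2.0 * scaling_reg))
    reg = scaling_reg**2 * np.sum(sigma - 0.5 * np.exp(-2.0 * sigma) + 0.5, axis=-1)
    return 0.5 * np.sum(z**2, axis=-1) + reg


def cost_matrix(x, y, scaling_reg=1.0):
    """C[i, j] = h(x[i] - y[j]) for x of shape (n, d) and y of shape (m, d)."""
    diffs = x[:, None, :] - y[None, :, :]   # shape (n, m, d) via broadcasting
    return stvs_cost(diffs, scaling_reg)     # h reduces the last axis -> (n, m)


rng = np.random.default_rng(0)
x, y = rng.normal(size=(4, 2)), rng.normal(size=(3, 2))
C = cost_matrix(x, y)
print(C.shape)  # (4, 3)
```

Materializing the (n, m, d) difference tensor is simple but uses O(nmd) memory. For large point clouds, evaluating the cost row by row or with a vectorized map avoids this.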
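The proximal, Legendre and twist methods rest on a single closed form. Take tau = 1 and no projection matrix, and let g be the penalty in the formula.

1. Solving \(x = p + g'(p)\) gives the STVS thresholding rule \(p = x\,(1 - \text{scaling\_reg}^2/x^2)_+\), elementwise.
2. \(h = \frac{1}{2}\|\cdot\|_2^2 + g\) is convex, so \((\nabla h)^{-1} = (I + \partial g)^{-1}\) is that same rule.
3. This rule inverts the twist \(\nabla h(x - y)\) of the TI cost.
4. Through the Fenchel-Young equality, it also gives \(h^*(w) = \langle w, z^*\rangle - h(z^*)\) with \(z^* = (\nabla h)^{-1}(w)\).

The NumPy sketch below checks these derivations numerically. All function names are hypothetical. It does not assume how OTT scales the threshold with tau, or how twist_operator maps vec, dual_vec and variable onto this inverse.

```python
import numpy as np


def stvs_penalty(p, scaling_reg=1.0):
    """Elementwise STVS penalty g(p); sigma is evaluated on |p|."""
    sigma = np.arcsinh(np.abs(p) / (2.0 * scaling_reg))
    return scaling_reg**2 * (sigma - 0.5 * np.exp(-2.0 * sigma) + 0.5)


def stvs_prox(x, scaling_reg=1.0):
    """Closed-form prox of g with tau = 1: x * (1 - scaling_reg^2 / x^2)_+."""
    x = np.asarray(x, dtype=float)
    safe = np.where(x == 0.0, 1.0, x)  # placeholder denominator where x == 0
    return np.where(np.abs(x) > scaling_reg, x - scaling_reg**2 / safe, 0.0)


def stvs_grad_h(p, scaling_reg=1.0):
    """Gradient of h = 0.5 * ||p||^2 + sum(g(p)) at points with nonzero entries."""
    p = np.asarray(p, dtype=float)
    # g'(p) = sign(p) * (sqrt(p^2 + 4 scaling_reg^2) - |p|) / 2
    dg = np.sign(p) * (np.sqrt(p**2 + 4.0 * scaling_reg**2) - np.abs(p)) / 2.0
    return p + dg


def stvs_legendre(w, scaling_reg=1.0):
    """h*(w) = <w, z*> - h(z*), with maximizer z* = (grad h)^{-1}(w) = prox(w)."""
    w = np.asarray(w, dtype=float)
    z_star = stvs_prox(w, scaling_reg)
    h_at_z = 0.5 * np.dot(z_star, z_star) + np.sum(stvs_penalty(z_star, scaling_reg))
    return float(np.dot(w, z_star) - h_at_z)


# prox: entries with |x| <= scaling_reg are zeroed, larger ones shrink by scaling_reg^2 / x
print(stvs_prox(np.array([-3.0, -0.5, 0.0, 0.8, 2.0])))

# brute-force check of the prox definition: argmin_p 0.5 * (p - 2)^2 + g(p)
grid = np.linspace(-10.0, 10.0, 200_001)
print(grid[np.argmin(0.5 * (grid - 2.0) ** 2 + stvs_penalty(grid))])  # close to 1.5

# twist inverse: grad h undoes the prox wherever |v| > scaling_reg
v = np.array([-4.0, 1.2, 3.0])
print(np.allclose(stvs_grad_h(stvs_prox(v)), v))  # True

# the Legendre transform vanishes on [-scaling_reg, scaling_reg]^d
print(stvs_legendre([0.5, 0.9]))  # 0.0
```

The shrinkage \(\text{scaling\_reg}^2/x\) tends to 0 as |x| grows, which is the vanishing shrinkage in the name. Plain soft thresholding instead subtracts a constant from every surviving entry.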

Attributes

norm