ott.geometry.low_rank.LRCGeometry

class ott.geometry.low_rank.LRCGeometry(cost_1, cost_2, bias=0.0, scale_factor=1.0, scale_cost=1.0, batch_size=None, **kwargs)

Geometry whose cost is defined by the product of two low-rank matrices.

Implements geometries defined as low-rank products, i.e. geometries for which there exist two matrices \(A\) and \(B\), each with \(r\) columns, such that the cost of the geometry equals \(AB^T\). Besides being faster to apply to a vector (computing \(A(B^T v)\) never forms the full cost matrix), these geometries have the property that adding two of them amounts to concatenating their factors: if \(C = AB^T\) and \(D = EF^T\), then \(C + D = [A,E][B,F]^T\).
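Both properties can be checked numerically. The following is a minimal NumPy sketch (jax.numpy accepts the same calls); it works on raw factor arrays rather than through the LRCGeometry API, and the helper `add_low_rank` is hypothetical.

```python
import numpy as np


def add_low_rank(a, b, e, f):
    """Factors of A @ B.T + E @ F.T, obtained by concatenating columns."""
    return np.concatenate([a, e], axis=1), np.concatenate([b, f], axis=1)


rng = np.random.default_rng(0)
num_a, num_b = 5, 4

# C = A B^T has rank 2, D = E F^T has rank 3.
A, B = rng.normal(size=(num_a, 2)), rng.normal(size=(num_b, 2))
E, F = rng.normal(size=(num_a, 3)), rng.normal(size=(num_b, 3))

left, right = add_low_rank(A, B, E, F)  # shapes [5, 5] and [4, 5]
assert np.allclose(left @ right.T, A @ B.T + E @ F.T)

# Applying C to a vector without forming the num_a x num_b matrix:
# A @ (B.T @ v) costs O((num_a + num_b) * r) instead of O(num_a * num_b).
v = rng.normal(size=num_b)
assert np.allclose(A @ (B.T @ v), (A @ B.T) @ v)
```

The number of columns grows additively (here 2 + 3 = 5), so repeatedly summing low-rank geometries gradually erodes the speed advantage.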

Parameters
  • cost_1 (Array) – first factor of the cost matrix, jnp.ndarray<float>[num_a, r].

  • cost_2 (Array) – second factor of the cost matrix, jnp.ndarray<float>[num_b, r].

  • bias (float) – constant added to the entire cost matrix.

  • scale_factor (float) – value used to rescale the factors of the low-rank geometry.

  • scale_cost (Union[bool, int, float, Literal['mean', 'max_bound', 'max_cost']]) – option to rescale the cost matrix. Implemented scalings are 'max_bound', 'mean' and 'max_cost'. Alternatively, a float factor can be given to rescale the cost such that cost_matrix /= scale_cost. If True, 'mean' is used.

  • batch_size (Optional[int]) – optional batch size used to compute the scale factor scale_cost online (without instantiating the full cost matrix) when scale_cost = 'max_cost'. If None, the batch size is set to 1024, or to the larger of the number of rows of cost_1 and cost_2 if that is smaller than 1024.

  • kwargs (Any) – additional keyword arguments for Geometry.
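The numeric options of scale_cost can be illustrated with plain arrays. The sketch below assumes that the cost is cost_1 @ cost_2.T + bias (with scale_factor left at its default of 1.0), that 'max_cost' is the largest entry of that matrix, and that batches run over blocks of rows of cost_1. The helpers `default_batch_size` and `max_cost_batched` are hypothetical and are not OTT's internal implementation.

```python
import numpy as np


def default_batch_size(num_a: int, num_b: int) -> int:
    # 1024, or the larger of num_a and num_b when that is smaller than 1024.
    return min(1024, max(num_a, num_b))


def max_cost_batched(cost_1, cost_2, bias=0.0, batch_size=None):
    """Largest entry of cost_1 @ cost_2.T + bias, one row block at a time."""
    num_a, num_b = cost_1.shape[0], cost_2.shape[0]
    if batch_size is None:
        batch_size = default_batch_size(num_a, num_b)
    running_max = -np.inf
    for start in range(0, num_a, batch_size):
        # Only a [batch_size, num_b] block of the cost is ever materialized.
        block = cost_1[start:start + batch_size] @ cost_2.T + bias
        running_max = max(running_max, float(block.max()))
    return running_max


rng = np.random.default_rng(0)
cost_1, cost_2 = rng.normal(size=(7, 3)), rng.normal(size=(5, 3))
bias = 0.5

full = cost_1 @ cost_2.T + bias  # materialized here only to check the sketch
scale_cost = max_cost_batched(cost_1, cost_2, bias, batch_size=2)
assert np.isclose(scale_cost, full.max())

# A numeric scale_cost divides the whole matrix: cost_matrix /= scale_cost.
rescaled = full / scale_cost
```

Memory use is bounded by one [batch_size, num_b] block, which is the point of computing the statistic online.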

Methods

apply_cost(arr[, axis, fn])

Apply cost_matrix to an array (vector or matrix).

apply_kernel(scaling[, eps, axis])

Apply kernel_matrix to a positive scaling vector.

apply_lse_kernel(f, g, eps[, vec, axis])

Apply kernel_matrix in log domain on a pair of dual potential variables.

apply_square_cost(arr[, axis])

Apply the elementwise square of the cost matrix to an array (vector or matrix).

apply_transport_from_potentials(f, g, vec[, ...])

Apply transport matrix computed from potentials to a (batched) vec.

apply_transport_from_scalings(u, v, vec[, axis])

Apply transport matrix computed from scalings to a (batched) vec.

compute_max_cost()

Compute the maximum of the cost_matrix.

copy_epsilon(other)

Copy the epsilon parameters from another geometry.

marginal_from_potentials(f, g[, axis])

Output a marginal of the transport matrix computed from potentials.

marginal_from_scalings(u, v[, axis])

Output a marginal of the transport matrix computed from scalings.

mask(src_mask, tgt_mask[, mask_value])

Mask rows or columns of a geometry.

potential_from_scaling(scaling)

Compute dual potential vector from scaling vector.

prepare_divergences(*args[, static_b])

Instantiate 2 (or 3) geometries to compute a Sinkhorn divergence.

scaling_from_potential(potential)

Compute scaling vector from dual potential.

subset(src_ixs, tgt_ixs, **kwargs)

Subset rows or columns of a geometry.

to_LRCGeometry([rank, tol, seed])

Return self.

transport_from_potentials(f, g)

Output transport matrix from potentials.

transport_from_scalings(u, v)

Output the transport matrix from a pair of scalings.

update_potential(f, g, log_marginal[, ...])

Carry out one Sinkhorn update for potentials, i.e. in log space.

update_scaling(scaling, marginal[, ...])

Carry out one Sinkhorn update for scalings, using kernel directly.
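The cost applications listed above need not materialize the cost matrix. A minimal NumPy sketch, assuming cost = cost_1 @ cost_2.T + bias: applying the cost reduces to two thin matrix products, and the elementwise square stays low rank because \((AB^T)^{\circ 2}\) factors through row-wise Kronecker products of the rows of \(A\) and \(B\), with rank at most \(r^2\). The helpers and the `transpose` flag are hypothetical; they do not reproduce OTT's `axis` convention, and whether OTT uses this exact expansion is not stated here.

```python
import numpy as np


def row_kron(x):
    """Row-wise Kronecker product: row i becomes kron(x[i], x[i])."""
    n, r = x.shape
    return np.einsum("ik,il->ikl", x, x).reshape(n, r * r)


def apply_cost_lr(cost_1, cost_2, bias, arr, transpose=False):
    """(A @ B.T + bias) @ arr, or the transposed cost applied, in factored form."""
    left, right = (cost_2, cost_1) if transpose else (cost_1, cost_2)
    return left @ (right.T @ arr) + bias * arr.sum(axis=0)


def apply_square_cost_lr(cost_1, cost_2, bias, arr, transpose=False):
    """Elementwise-squared cost applied to arr, without forming the matrix.

    (A B^T + c)**2 = (A B^T)**2 + 2 c A B^T + c**2, and (A B^T)**2 is itself
    low rank (at most r**2) with factors row_kron(A) and row_kron(B).
    """
    left, right = (cost_2, cost_1) if transpose else (cost_1, cost_2)
    squared = row_kron(left) @ (row_kron(right).T @ arr)
    linear = left @ (right.T @ arr)
    return squared + 2.0 * bias * linear + bias**2 * arr.sum(axis=0)


rng = np.random.default_rng(1)
A, B = rng.normal(size=(6, 2)), rng.normal(size=(4, 2))
c = 0.3
C = A @ B.T + c  # dense reference, for checking only
v, u = rng.normal(size=4), rng.normal(size=6)
assert np.allclose(apply_cost_lr(A, B, c, v), C @ v)
assert np.allclose(apply_square_cost_lr(A, B, c, u, transpose=True), (C**2).T @ u)
```

The square trick replaces the num_a x num_b matrix with factors of \(r^2\) columns, so it pays off roughly while \(r^2\) stays well below min(num_a, num_b).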

Attributes

bias

Constant offset added to the entire cost_matrix.

can_LRC

Quickly check whether casting the geometry as an LRC geometry makes sense.

cost_1

First factor of the cost_matrix.

cost_2

Second factor of the cost_matrix.

cost_matrix

Materialize the cost matrix.

cost_rank

Output rank of cost matrix, if any was provided.

dtype

The data type.

epsilon

Epsilon regularization value.

inv_scale_cost

Compute and return the inverse of the scaling factor for the cost matrix.

is_online

Whether geometry cost/kernel should be recomputed on the fly.

is_squared_euclidean

Whether the cost is computed as the squared Euclidean distance between points.

is_symmetric

Whether geometry cost/kernel is a symmetric matrix.

kernel_matrix

Kernel matrix, either provided by user or recomputed from cost_matrix.

mean_cost_matrix

Mean of the cost_matrix.

median_cost_matrix

Median of the cost_matrix.

scale_epsilon

Compute the scale of the epsilon, potentially based on data.

shape

Shape of the geometry.

src_mask

Mask of shape [num_a,] to compute cost_matrix statistics.

tgt_mask

Mask of shape [num_b,] to compute cost_matrix statistics.
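Statistics such as mean_cost_matrix can also be obtained from the factors. By bilinearity, the mean of \(a_i^T b_j + \text{bias}\) over selected rows \(i\) and columns \(j\) equals the inner product of the mean selected row of cost_1 with the mean selected row of cost_2, plus the bias. A minimal NumPy sketch, assuming src_mask and tgt_mask are boolean arrays marking which rows and columns count; `masked_mean_cost` is a hypothetical helper, not the OTT property.

```python
import numpy as np


def masked_mean_cost(cost_1, cost_2, bias=0.0, src_mask=None, tgt_mask=None):
    """Mean of cost_1 @ cost_2.T + bias over selected rows and columns,
    computed in O((num_a + num_b) * r) without forming the matrix."""
    if src_mask is None:
        src_mask = np.ones(cost_1.shape[0], dtype=bool)
    if tgt_mask is None:
        tgt_mask = np.ones(cost_2.shape[0], dtype=bool)
    # mean_{i,j} <a_i, b_j> = <mean_i a_i, mean_j b_j> by bilinearity.
    mean_a = cost_1[src_mask].mean(axis=0)
    mean_b = cost_2[tgt_mask].mean(axis=0)
    return float(mean_a @ mean_b) + bias


rng = np.random.default_rng(2)
A, B = rng.normal(size=(6, 3)), rng.normal(size=(5, 3))
src_mask = np.array([True, False, True, True, False, True])
tgt_mask = np.array([False, True, True, True, True])
mean_cost = masked_mean_cost(A, B, 0.25, src_mask, tgt_mask)
```

The same shortcut does not extend to median_cost_matrix, since the median is not linear in the entries.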