ott.geometry.grid.Grid
- class ott.geometry.grid.Grid(x=None, grid_size=None, cost_fns=None, num_a=None, grid_dimension=None, **kwargs)
Class describing the geometry of points taken in a Cartesian product.
This class implements a geometry in which probability measures are supported on a \(d\)-dimensional Cartesian grid, that is, the Cartesian product of \(d\) lists of values, where the \(i\)-th list has size \(n_i\).
The transportation cost between points in the grid is assumed to be separable, namely a sum of coordinate-wise cost functions, as in
\[\mathrm{cost}(x,y) = \sum_{i=1}^d \mathrm{cost}_i(x_i, y_i),\]
where \(\mathrm{cost}_i : \mathbb{R} \times \mathbb{R} \to \mathbb{R}\).
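The separability assumption can be checked numerically. The following is a minimal NumPy sketch, not the library's implementation: the grid values, the squared-difference cost `sq_diff`, and all variable names are illustrative assumptions. It shows that the dense cost matrix over all grid points equals a Kronecker sum of the small per-dimension cost matrices.

```python
import numpy as np

# Hypothetical 2-D grid: d = 2 coordinate lists of sizes n_1 = 3 and n_2 = 4.
x = [np.array([0.0, 0.5, 1.0]), np.array([0.0, 1.0, 2.0, 3.0])]

def sq_diff(s, t):
    """Illustrative coordinate-wise cost (an assumption): squared difference."""
    return (s - t) ** 2

# One n_i x n_i cost matrix per dimension.
cost_1d = [sq_diff(xi[:, None], xi[None, :]) for xi in x]

# Dense reference: enumerate all grid points in row-major order and
# sum the coordinate-wise costs between every pair of points.
points = np.stack([g.ravel() for g in np.meshgrid(*x, indexing="ij")], axis=-1)
dense = sum(sq_diff(points[:, None, i], points[None, :, i]) for i in range(len(x)))

# The same matrix assembled from the small matrices only (a Kronecker sum).
ones = [np.ones_like(c) for c in cost_1d]
kron_sum = np.kron(cost_1d[0], ones[1]) + np.kron(ones[0], cost_1d[1])
```

The dense matrix has \(12 \times 12\) entries, while the two per-dimension matrices hold only \(9 + 16\) values; this gap is what a grid geometry exploits.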
In such a regime, and despite the fact that the total number \(n_{total}\) of points in the grid is exponential in \(d\) (namely \(\prod_i n_i\)), applying a kernel in the context of regularized optimal transport can be carried out in time that is of the order of \(n_{total}^{(1+1/d)}\) using convolutions, either in the original domain or in the log-space domain. This class precomputes \(d\) cost matrices of size \(n_i \times n_i\) (one per dimension) and implements these two operations by carrying out the convolutions one dimension at a time.
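To illustrate the one-dimension-at-a-time idea, here is a minimal NumPy sketch under stated assumptions: a squared-difference cost per coordinate, a Gibbs kernel \(\exp(-\mathrm{cost}_i/\varepsilon)\) per dimension, and hypothetical values for `grid_size` and `eps`. The helper `apply_kernel_separable` is not part of the library; it only demonstrates that contracting each small kernel along its own axis matches multiplication by the dense kernel.

```python
import numpy as np

rng = np.random.default_rng(0)
grid_size = (3, 4, 5)   # hypothetical (n_1, n_2, n_3)
eps = 0.1               # hypothetical regularization strength
x = [np.linspace(0.0, 1.0, n) for n in grid_size]

# One small Gibbs kernel exp(-cost_i / eps) per dimension (squared-difference cost).
kernels_1d = [np.exp(-((xi[:, None] - xi[None, :]) ** 2) / eps) for xi in x]

def apply_kernel_separable(v_flat):
    """Apply the full grid kernel to a flat vector, one dimension at a time."""
    v = v_flat.reshape(grid_size)
    for axis, k in enumerate(kernels_1d):
        # Contract k with the chosen axis, then move the new axis back in place.
        v = np.moveaxis(np.tensordot(k, v, axes=([1], [axis])), 0, axis)
    return v.ravel()

v = rng.random(int(np.prod(grid_size)))

# Dense reference: the full kernel is the Kronecker product of the small kernels.
dense_kernel = np.kron(np.kron(kernels_1d[0], kernels_1d[1]), kernels_1d[2])
result = apply_kernel_separable(v)
```

Each pass over one axis costs about \(n_{total} \cdot n_i\) operations. When all \(n_i\) are equal to \(n\), the total is \(d \cdot n^{d+1} = d \cdot n_{total}^{1+1/d}\), compared with \(n_{total}^2\) for the dense product.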
- Parameters
  - x (Optional[Sequence[ndarray]]) – list of arrays of varying sizes, describing the locations of the grid. Locations are provided as a list of jnp.ndarrays, that is, \(d\) vectors of (possibly varying) size \(n_i\). The resulting grid is the Cartesian product of these vectors.
  - grid_size (Optional[Sequence[int]]) – tuple of integers describing grid sizes, namely \((n_1,...,n_d)\). This is only used if x is None. In that case the grid is assumed to lie in the hypercube \([0,1]^d\), with each of the \(d\) dimensions described by points regularly sampled in \([0,1]\).
  - cost_fns (Optional[Sequence[CostFn]]) – a sequence of \(d\) costs.CostFn objects, each being a cost that takes two reals as inputs and outputs a real number.
  - num_a (Optional[int]) – total size of the grid. This parameter is computed from the other inputs and used in the flatten/unflatten functions.
  - grid_dimension (Optional[int]) – dimension of the grid. This parameter is computed from the other inputs and used in the flatten/unflatten functions.
  - kwargs (Any) – other optional parameters to be passed on to the superclass initializer, notably those related to epsilon regularization.
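The relationship between grid_size, num_a and grid_dimension can be sketched in plain NumPy. This is an illustration only: it assumes that "regularly sampled in \([0,1]\)" means evenly spaced points with both endpoints included, and that flattening uses row-major order; neither detail is confirmed by the text above.

```python
import math
import numpy as np

grid_size = (3, 4)   # hypothetical (n_1, n_2)

# Assumption: each axis is sampled regularly in [0, 1], endpoints included.
x = [np.linspace(0.0, 1.0, n) for n in grid_size]

grid_dimension = len(grid_size)   # d
num_a = math.prod(grid_size)      # n_1 * ... * n_d

# Flat list of grid points, shape (num_a, d), in row-major order (an assumption).
points = np.stack([g.ravel() for g in np.meshgrid(*x, indexing="ij")], axis=-1)

# Flatten / unflatten a uniform measure supported on the grid.
weights = np.full(grid_size, 1.0 / num_a)
flat = weights.reshape(num_a)
unflat = flat.reshape(grid_size)
```

Passing explicit coordinate vectors through x instead would simply replace the `np.linspace` calls with user-provided arrays of sizes \(n_i\).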
Methods
apply_cost(arr[, axis, fn]) – Apply cost matrix to array (vector or matrix).
apply_kernel(scaling[, eps, axis]) – Apply grid kernel on scaling vector.
apply_lse_kernel(f, g, eps[, vec, axis]) – Apply grid kernel in log space.
apply_square_cost(arr[, axis]) – Apply elementwise-square of cost matrix to array (vector or matrix).
apply_transport_from_potentials(f, g, vec[, ...]) – Apply transport matrix computed from potentials to a (batched) vec.
apply_transport_from_scalings(u, v, vec[, axis]) – Apply transport matrix computed from scalings to a (batched) vec.
copy_epsilon(other) – Copy the epsilon parameters from another geometry.
marginal_from_potentials(f, g[, axis]) – Output marginal of transportation matrix from potentials.
marginal_from_scalings(u, v[, axis]) – Output marginal of transportation matrix from scalings.
potential_from_scaling(scaling) – Compute dual potential vector from scaling vector.
prepare_divergences(*args[, static_b]) – Instantiate the geometries used for a divergence computation.
rescale_cost_fn(factor) – Rescale the cost or kernel matrix using a factor.
scaling_from_potential(potential) – Compute scaling vector from dual potential.
transport_from_potentials(f, g[, axis]) – Output transport matrix from potentials.
transport_from_scalings(f, g[, axis]) – Output transport matrix from pair of scalings.
update_potential(f, g, log_marginal[, ...]) – Carry out one Sinkhorn update for potentials, i.e. in log space.
update_scaling(scaling, marginal[, ...]) – Carry out one Sinkhorn update for scalings, using kernel directly.
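Several of the methods above correspond to standard steps of entropic optimal transport. The sketch below uses a small dense kernel in NumPy to show how scalings, potentials, Sinkhorn updates and marginals relate. It is not the library's code: the helper functions only borrow the method names for readability, and the convention scaling = exp(potential / eps) is the usual one in the literature, assumed here rather than taken from this page.

```python
import numpy as np

eps = 0.5                                  # hypothetical regularization
xs = np.linspace(0.0, 1.0, 5)
cost = (xs[:, None] - xs[None, :]) ** 2    # illustrative dense cost matrix
kernel = np.exp(-cost / eps)               # Gibbs kernel

a = np.full(5, 0.2)                        # source marginal
b = np.array([0.1, 0.2, 0.4, 0.2, 0.1])    # target marginal

def scaling_from_potential(potential):
    """Stand-alone helper (assumed convention): scaling = exp(potential / eps)."""
    return np.exp(potential / eps)

def potential_from_scaling(scaling):
    """Inverse of the helper above: potential = eps * log(scaling)."""
    return eps * np.log(scaling)

u, v = np.ones(5), np.ones(5)
for _ in range(500):
    u = a / (kernel @ v)      # scaling update against the first marginal
    v = b / (kernel.T @ u)    # scaling update against the second marginal

# Transport matrix from scalings, and the same matrix from potentials.
transport = u[:, None] * kernel * v[None, :]
f, g = potential_from_scaling(u), potential_from_scaling(v)
transport_from_f_g = np.exp((f[:, None] + g[None, :] - cost) / eps)
```

On a grid geometry the products `kernel @ v` and `kernel.T @ u` would be replaced by dimension-wise kernel applications, so the dense matrix is never formed.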
Attributes
Cost matrix, recomputed from kernel if only kernel was specified.
Output rank of cost matrix, if any was provided.
Epsilon regularization value.
Compute and return inverse of scaling factor for cost matrix.
Whether geometry cost/kernel should be recomputed on the fly.
Whether cost is computed by taking squared-Eucl.
Whether geometry cost/kernel is a symmetric matrix.
Kernel matrix, either provided by user or recomputed from cost.
Mean of cost matrix.
Median of cost matrix.
Compute the scale of the epsilon, potentially based on data.
Shape of cost or kernel matrix.