ott.experimental.mmsinkhorn.MMSinkhornOutput
- class ott.experimental.mmsinkhorn.MMSinkhornOutput(potentials, errors, x_s=None, a_s=None, cost_fns=None, epsilon=None, ent_reg_cost=None, threshold=None, converged=None, inner_iterations=None)
Output of the MMSinkhorn solver used on \(k\) point clouds.
This class contains both the solution and the definition of a regularized multimarginal OT (MM-OT) problem involving \(k\) weighted point clouds of varying sizes, along with methods and properties that use or describe that solution.
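To make the quantities stored in this output concrete, the following is a minimal NumPy sketch of a log-domain multimarginal Sinkhorn loop on \(k = 3\) small point clouds. It assumes a squared-Euclidean cost summed over all \(k(k-1)/2\) pairs of clouds, and the convention that the transport tensor is \(\exp((f_1 \oplus \dots \oplus f_k - C)/\varepsilon)\) with the weights absorbed into the potentials. The helpers mm_sinkhorn, cost_tensor, expand and logsumexp are hypothetical names for illustration; this is not the OTT implementation.

```python
import numpy as np


def logsumexp(x, axis):
    """Numerically stable log(sum(exp(x))) over the given axes."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def expand(f, i, k):
    """Reshape a 1-D potential so it broadcasts along axis i of a k-D tensor."""
    shape = [1] * k
    shape[i] = f.shape[0]
    return f.reshape(shape)


def cost_tensor(x_s):
    """Sum of squared-Euclidean costs over all k(k-1)/2 pairs (assumed cost)."""
    k = len(x_s)
    sizes = [x.shape[0] for x in x_s]
    c = np.zeros(sizes)
    for i in range(k):
        for j in range(i + 1, k):
            cij = np.sum((x_s[i][:, None, :] - x_s[j][None, :, :]) ** 2, axis=-1)
            shape = [1] * k
            shape[i], shape[j] = sizes[i], sizes[j]
            c = c + cij.reshape(shape)  # broadcast the pairwise cost into k-D
    return c


def mm_sinkhorn(x_s, a_s, epsilon, n_iters=1000):
    """Hypothetical log-domain multimarginal Sinkhorn; returns (potentials, tensor)."""
    k = len(x_s)
    c = cost_tensor(x_s)
    f_s = [np.zeros(x.shape[0]) for x in x_s]
    for _ in range(n_iters):
        for i in range(k):
            others = tuple(j for j in range(k) if j != i)
            # Update f_i so that the i-th marginal of the tensor matches a_s[i].
            z = sum(expand(f_s[j], j, k) for j in others) - c
            f_s[i] = epsilon * np.log(a_s[i]) - epsilon * logsumexp(z / epsilon, axis=others)
    total = sum(expand(f, i, k) for i, f in enumerate(f_s))
    tensor = np.exp((total - c) / epsilon)
    return tuple(f_s), tensor


rng = np.random.default_rng(0)
sizes = (4, 5, 6)
x_s = [rng.normal(scale=0.5, size=(n, 2)) for n in sizes]
a_s = [np.full(n, 1.0 / n) for n in sizes]
potentials, tensor = mm_sinkhorn(x_s, a_s, epsilon=1.0)
print(tensor.shape)  # (4, 5, 6)
```

Each potential has one entry per point of its cloud, matching the description of potentials, and the tensor has one axis per cloud.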
- Parameters:
  - potentials (Tuple[Array, ...]) – Tuple of \(k\) optimal dual variables; the \(i\)-th is a vector whose size equals the number of points \(n_i\) in the \(i\)-th point cloud.
  - errors (Array) – Vector of errors along iterations. Its size is max_iterations // inner_iterations, where both values are the parameters passed to the MMSinkhorn solver. Follows the conventions used in errors.
  - x_s (Optional[Tuple[Array, ...]]) – Tuple of \(k\) point clouds; x_s[i] is a matrix of size \(n_i \times d\), where \(d\) is common to all point clouds.
  - a_s (Optional[Tuple[Array, ...]]) – Tuple of \(k\) probability vectors, each of size \(n_i\).
  - cost_fns (Union[CostFn, Tuple[CostFn, ...], None]) – Cost function, or a tuple of \(k(k-1)/2\) such instances (one per pair of point clouds).
  - epsilon (Optional[float]) – Entropic regularization used to solve the multimarginal Sinkhorn problem.
  - ent_reg_cost (Optional[Array]) – The regularized optimal transport cost: the linear term (dot product between the optimal tensor and the cost tensor) minus epsilon times the entropy.
  - threshold (Optional[Array]) – Convergence threshold used to control the termination of the algorithm.
  - converged (Optional[bool]) – Whether the output corresponds to a solution whose error is below the convergence threshold.
  - inner_iterations (Optional[int]) – Number of iterations run between two successive computations of the error.
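The relationship between errors, inner_iterations, threshold and converged can be illustrated with a toy driver loop. This is a sketch under stated assumptions: run_with_error_log and halve are hypothetical names, slots that are never reached are assumed to stay at -1, and the iteration count is assumed to be the number of logged errors times inner_iterations.

```python
import numpy as np


def run_with_error_log(step, error_fn, max_iterations=100, inner_iterations=10,
                       threshold=1e-3):
    """Hypothetical driver that logs an error every `inner_iterations` steps.

    Slots never reached keep the value -1 (assumed padding convention).
    """
    errors = -np.ones(max_iterations // inner_iterations)
    state, converged = None, False
    for it in range(max_iterations):
        state = step(state)
        if (it + 1) % inner_iterations == 0:
            err = error_fn(state)
            errors[(it + 1) // inner_iterations - 1] = err
            if err < threshold:  # stop once the error drops below threshold
                converged = True
                break
    return state, errors, converged


def halve(state):
    """Toy update: the 'error' halves at every iteration."""
    return 1.0 if state is None else state / 2


inner = 10
state, errors, converged = run_with_error_log(
    halve, error_fn=lambda s: s, inner_iterations=inner)
# Assumed relation: iterations run = (number of logged errors) * inner_iterations.
n_iters = int(np.sum(errors != -1)) * inner
print(converged, n_iters)  # True 20
```

Here errors has max_iterations // inner_iterations = 10 slots, but only the first two are filled before the threshold is met.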
Methods
count(value, /) – Return number of occurrences of value.
index(value[, start, stop]) – Return first index of value.
marginal(k) – Return the marginal probability weight vector at slice \(k\).
set(**kwargs) – Return a copy of self with the given fields overwritten.
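The count/index methods and the "Alias for field number" attributes suggest the class is a named tuple. Assuming that, set(**kwargs) behaves like NamedTuple._replace and marginal(k) sums the transport tensor over every axis except \(k\). ToyMMOutput below is a hypothetical stand-in that stores the tensor as a field for brevity; it is not the real class.

```python
from typing import NamedTuple, Tuple

import numpy as np


class ToyMMOutput(NamedTuple):
    """Hypothetical stand-in mirroring a few fields; not the real class."""
    potentials: Tuple[np.ndarray, ...]
    tensor: np.ndarray
    epsilon: float = 1.0

    def marginal(self, k: int) -> np.ndarray:
        """Sum the tensor over every axis except axis k."""
        axes = tuple(i for i in range(self.tensor.ndim) if i != k)
        return self.tensor.sum(axis=axes)

    def set(self, **kwargs) -> "ToyMMOutput":
        """Return a copy with some fields overwritten (NamedTuple._replace)."""
        return self._replace(**kwargs)


t = np.full((2, 3, 4), 1.0 / 24)  # uniform tensor with total mass 1
out = ToyMMOutput(potentials=(np.zeros(2), np.zeros(3), np.zeros(4)), tensor=t)
print(out.marginal(1))             # three entries, each 1/3
out2 = out.set(epsilon=0.1)
print(out.epsilon, out2.epsilon)   # 1.0 0.1
```

Because _replace builds a new tuple, the original output is left untouched and unchanged fields are shared rather than copied.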
Attributes
Alias for field number 3 (a_s).
Alias for field number 8 (converged).
Alias for field number 4 (cost_fns).
Cost tensor.
Alias for field number 6 (ent_reg_cost).
Alias for field number 5 (epsilon).
Alias for field number 1 (errors).
Alias for field number 9 (inner_iterations).
\(k\) marginal probability weight vectors.
Total number of iterations that were needed to terminate.
Number of marginals.
Alias for field number 0 (potentials).
Shape of the transport tensor.
Transport tensor.
Alias for field number 7 (threshold).
Sum of the transport tensor.
Alias for field number 2 (x_s).
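Several of the attributes above are simple functions of the transport tensor and the cost tensor. The sketch below derives them with a hypothetical summarize helper. It assumes the entropy is \(H(P) = -\sum P \log P\); formulations that use \(-\sum P(\log P - 1)\) instead would shift ent_reg_cost by epsilon times the total mass.

```python
import numpy as np


def summarize(tensor, cost_t, epsilon):
    """Hypothetical helper deriving output attributes from a transport tensor."""
    k = tensor.ndim
    marginals = tuple(
        tensor.sum(axis=tuple(j for j in range(k) if j != i)) for i in range(k)
    )
    linear = float(np.sum(tensor * cost_t))  # dot product <P, C>
    # Assumed entropy H(P) = -sum P log P; zero entries contribute 0.
    pos = tensor > 0
    entropy = -float(np.sum(tensor[pos] * np.log(tensor[pos])))
    return {
        "n_marginals": k,
        "shape": tensor.shape,
        "marginals": marginals,
        "transport_mass": float(tensor.sum()),
        "ent_reg_cost": linear - epsilon * entropy,
    }


p = np.full((2, 2, 2), 1.0 / 8)  # uniform transport tensor
c = np.ones((2, 2, 2))           # constant cost tensor
stats = summarize(p, c, epsilon=0.1)
print(stats["shape"], stats["transport_mass"])  # (2, 2, 2) 1.0
```

With a uniform tensor and unit cost, the linear term is 1 and the entropy is \(\log 8\), so ent_reg_cost equals \(1 - 0.1 \log 8\).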