ott.experimental.mmsinkhorn.MMSinkhornOutput

class ott.experimental.mmsinkhorn.MMSinkhornOutput(potentials, errors, x_s=None, a_s=None, cost_fns=None, epsilon=None, ent_reg_cost=None, threshold=None, converged=None, inner_iterations=None)

Output of the MMSinkhorn solver used on \(k\) point clouds.

This class contains both the solution and the problem definition of a regularized multimarginal optimal transport (MM-OT) problem involving \(k\) weighted point clouds of varying sizes, along with methods and properties that use or describe the solution.

Parameters:
  • potentials (Tuple[Array, ...]) – Tuple of \(k\) optimal dual variables, vectors of sizes equal to the number of points in each of the \(k\) point clouds.

  • errors (Array) – Vector of errors recorded along iterations. This vector has size max_iterations // inner_iterations, where both are parameters passed to the MMSinkhorn solver. Follows the conventions used in errors.

  • x_s (Tuple[Array, ...] | None) – Tuple of \(k\) point clouds; x_s[i] is a matrix of size \(n_i \times d\), where \(d\) is common to all point clouds.

  • a_s (Tuple[Array, ...] | None) – Tuple of \(k\) probability vectors, each of size \(n_i\).

  • cost_fns (CostFn | Tuple[CostFn, ...] | None) – Cost function, or a tuple of \(k(k-1)/2\) such instances.

  • epsilon (float | None) – Entropic regularization used to solve the multimarginal Sinkhorn problem.

  • ent_reg_cost (Array | None) – The regularized optimal transport cost: the linear contribution (dot product between the optimal tensor and the cost tensor) minus epsilon times the entropy.

  • threshold (Array | None) – Convergence threshold used to control the termination of the algorithm.

  • converged (bool | None) – Whether the output corresponds to a solution whose error is below the convergence threshold.

  • inner_iterations (int | None) – Number of iterations that were run between two computations of errors.
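The count \(k(k-1)/2\) given for cost_fns is the number of unordered pairs of point clouds. The sketch below illustrates this in plain Python. It assumes, without confirmation from this page, that the cost tensor is the sum of pairwise costs over all such pairs; pairwise_cost_tensor and sq_euclidean are hypothetical helpers, not part of the library.

```python
from itertools import combinations, product


def sq_euclidean(x, y):
    """Squared Euclidean distance between two points given as sequences."""
    return sum((xi - yi) ** 2 for xi, yi in zip(x, y))


def pairwise_cost_tensor(x_s, cost_fn=sq_euclidean):
    """Hypothetical helper: sum one pairwise cost over all pairs of clouds.

    Returns (shape, tensor, pairs), where tensor maps an index tuple
    (i_1, ..., i_k) to the summed cost. Assumes the same cost function
    is used for every pair of point clouds.
    """
    k = len(x_s)
    shape = tuple(len(x) for x in x_s)
    pairs = list(combinations(range(k), 2))  # k(k-1)/2 unordered pairs
    tensor = {}
    for idx in product(*(range(n) for n in shape)):
        tensor[idx] = sum(
            cost_fn(x_s[a][idx[a]], x_s[b][idx[b]]) for a, b in pairs
        )
    return shape, tensor, pairs


# Three toy point clouds in d = 2 with sizes n_1 = 2, n_2 = 3, n_3 = 1.
x_s = (
    [(0.0, 0.0), (1.0, 0.0)],
    [(0.0, 1.0), (1.0, 1.0), (2.0, 1.0)],
    [(0.5, 0.5)],
)
shape, cost_t, pairs = pairwise_cost_tensor(x_s)
print(len(pairs), shape)
```

With \(k = 3\) clouds there are three pairs, and the tensor has one entry per combination of points, so its shape is \(n_1 \times n_2 \times n_3\).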

Methods

count(value, /)

Return number of occurrences of value.

index(value[, start, stop])

Return first index of value.

marginal(k)

Return the marginal probability weight vector at slice \(k\).
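As a rough illustration of what a marginal of a \(k\)-way transport tensor is, the plain-Python sketch below sums a tensor over every axis except one. It is not the library's implementation; the tensor is stored as a dictionary keyed by index tuples, and the function name marginal is reused here only for readability.

```python
from itertools import product


def marginal(tensor, shape, k):
    """Sum the tensor over all axes except axis k."""
    out = [0.0] * shape[k]
    for idx in product(*(range(n) for n in shape)):
        out[idx[k]] += tensor[idx]
    return out


# Toy example: the independent coupling of three probability vectors,
# whose marginals are the vectors themselves.
a_s = ([0.5, 0.5], [0.2, 0.3, 0.5], [0.25, 0.75])
shape = tuple(len(a) for a in a_s)
tensor = {
    idx: a_s[0][idx[0]] * a_s[1][idx[1]] * a_s[2][idx[2]]
    for idx in product(*(range(n) for n in shape))
}

marginals = [marginal(tensor, shape, k) for k in range(len(shape))]
transport_mass = sum(tensor.values())
print(marginals, transport_mass)
```

For a tensor that exactly couples the weights a_s, each marginal matches the corresponding weight vector and the total mass is 1, up to floating-point error.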

set(**kwargs)

Return a copy of self, with potential overwrites.
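The count and index methods, together with the "Alias for field number" entries under Attributes, suggest that the class behaves like a named tuple. The sketch below uses a hypothetical stand-in, ToyOutput, with only the first four fields, to show what this implies. Implementing set through _replace is an assumption, not something this page confirms.

```python
from typing import Any, NamedTuple, Optional, Tuple


class ToyOutput(NamedTuple):
    """Hypothetical stand-in using the first four documented fields."""
    potentials: Tuple[Any, ...]
    errors: Any
    x_s: Optional[Tuple[Any, ...]] = None
    a_s: Optional[Tuple[Any, ...]] = None

    def set(self, **kwargs: Any) -> "ToyOutput":
        """Return a copy of self with the given fields overwritten."""
        return self._replace(**kwargs)


out = ToyOutput(potentials=((0.0, 0.1),), errors=(0.5, 0.01))
updated = out.set(a_s=((0.5, 0.5),))

# Named tuples are immutable: the original is unchanged by set.
print(out.a_s, updated.a_s)
# Field number 3 and the a_s attribute refer to the same entry.
print(updated[3] == updated.a_s)
```

Because the object is a tuple, count and index operate on the field values themselves, which is rarely useful in practice; attribute access by name is the more readable route.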

Attributes

a_s

Alias for field number 3

converged

Alias for field number 8

cost_fns

Alias for field number 4

cost_t

Cost tensor.

ent_reg_cost

Alias for field number 6

epsilon

Alias for field number 5

errors

Alias for field number 1

inner_iterations

Alias for field number 9

marginals

\(k\) marginal probability weight vectors.

n_iters

Total number of iterations that were needed to terminate.
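One plausible way to relate n_iters to errors and inner_iterations is sketched below. It assumes that unused slots of the errors vector are filled with -1, a convention this page does not spell out; count_iterations is a hypothetical helper and all values are made up for illustration.

```python
def count_iterations(errors, inner_iterations):
    """Hypothetical helper: recorded error slots times inner_iterations.

    Assumes slots that were never reached hold the sentinel value -1.
    """
    recorded = sum(1 for e in errors if e != -1)
    return recorded * inner_iterations


max_iterations, inner_iterations = 100, 10
n_slots = max_iterations // inner_iterations  # size of the errors vector

# Illustrative run that stopped after four error evaluations.
errors = [0.3, 0.05, 0.004, 0.0002] + [-1.0] * (n_slots - 4)
threshold = 1e-3

n_iters = count_iterations(errors, inner_iterations)
last_error = [e for e in errors if e != -1][-1]
converged = last_error < threshold
print(n_slots, n_iters, converged)
```

Under these assumptions the result is always a multiple of inner_iterations, since the error is only evaluated once per block of inner iterations.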

n_marginals

Number of marginals.

potentials

Alias for field number 0

shape

Shape of the transport tensor.

tensor

Transport tensor.

threshold

Alias for field number 7

transport_mass

Sum of transport tensor.

x_s

Alias for field number 2