ott.solvers.quadratic.gromov_wasserstein.GWOutput

class ott.solvers.quadratic.gromov_wasserstein.GWOutput(costs=None, linear_convergence=None, converged=False, errors=None, linear_state=None, geom=None, old_transport_mass=1.0)

Holds the output of the Gromov-Wasserstein solver.

Parameters:
  • costs (Array | None) – Holds the sequence of regularized GW costs seen through the outer loop of the solver.

  • linear_convergence (Array | None) – Holds the sequence of boolean convergence flags of the inner Sinkhorn iterations.

  • converged (bool) – Convergence flag for the outer GW iterations.

  • errors (Array | None) – Holds the sequence of vectors of errors of the Sinkhorn algorithm, one vector per outer iteration.

  • linear_state (SinkhornOutput | LRSinkhornOutput | None) – State used to solve and store solutions to the local linearization of GW.

  • geom (Geometry | None) – The geometry underlying the local linearization.

  • old_transport_mass (float) – Holds the total mass of the transport at the previous iteration.
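For intuition about `linear_state`, `geom`, `costs`, and the convergence fields, here is a minimal NumPy sketch of an entropic GW outer loop, assuming the square loss and the linearization of Peyré, Cuturi and Solomon (2016). Each outer iteration builds the linearized cost matrix L(Cx, Cy) ⊗ P from the current plan and solves that linear problem with Sinkhorn. The helpers `linearized_cost`, `sinkhorn`, and `entropic_gw` are hypothetical and are not part of OTT; unlike the `costs` field documented above, the values recorded here omit the entropy term.

```python
import numpy as np


def linearized_cost(cx, cy, p):
    # Square-loss tensor product (L (x) P)_ij = sum_kl (cx_ik - cy_jl)^2 p_kl,
    # expanded so that no 4-D tensor is ever materialized.
    a, b = p.sum(axis=1), p.sum(axis=0)
    const = (cx**2 @ a)[:, None] + (cy**2 @ b)[None, :]
    return const - 2.0 * cx @ p @ cy.T


def sinkhorn(cost, a, b, eps, max_iter=5000, tol=1e-9):
    # Subtracting the minimum rescales the kernel but leaves the plan unchanged.
    k = np.exp(-(cost - cost.min()) / eps)
    v = np.ones_like(b)
    for _ in range(max_iter):
        u = a / (k @ v)
        v = b / (k.T @ u)
        p = u[:, None] * k * v[None, :]
        # Column marginals are exact after the v-update; track the row error.
        err = np.abs(p.sum(axis=1) - a).sum()
        if err < tol:
            return p, True, err
    return p, False, err


def entropic_gw(cx, cy, a, b, eps=0.5, outer_iter=20, outer_tol=1e-7):
    p = np.outer(a, b)  # start from the independent coupling
    costs, linear_convergence, converged = [], [], False
    for it in range(outer_iter):
        # Solve the linear problem defined by the current linearization.
        p, ok, _ = sinkhorn(linearized_cost(cx, cy, p), a, b, eps)
        # GW objective <L (x) P, P> of the new plan (no entropy term).
        costs.append(float(np.sum(linearized_cost(cx, cy, p) * p)))
        linear_convergence.append(ok)
        if it > 0 and abs(costs[-1] - costs[-2]) < outer_tol:
            converged = True
            break
    return p, np.array(costs), np.array(linear_convergence), converged


rng = np.random.default_rng(0)
x, y = rng.random((5, 2)), rng.random((6, 3))
cx = ((x[:, None] - x[None]) ** 2).sum(-1)
cy = ((y[:, None] - y[None]) ** 2).sum(-1)
cx, cy = cx / cx.max(), cy / cy.max()  # put both cost matrices on one scale
a, b = np.full(5, 1 / 5), np.full(6, 1 / 6)
p, costs, linear_convergence, converged = entropic_gw(cx, cy, a, b)
```

The expansion of the linearized cost avoids building the n × n × m × m loss tensor, which is why GW solvers typically only need products of the form Cx P Cyᵀ.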

Methods

apply(inputs[, axis])

Apply the transport to an array; axis=1 for its transpose.
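Applying a transport plan amounts to a matrix product with P or with its transpose. The sketch below uses a hypothetical helper `apply_plan` on a plain NumPy matrix, with the convention (an assumption, not necessarily OTT's axis or batch layout) that `axis=0` returns `P @ inputs` and `axis=1` returns `P.T @ inputs`. Applying either one to a vector of ones recovers a marginal of the plan.

```python
import numpy as np


def apply_plan(p, inputs, axis=0):
    # Hypothetical convention: axis=0 applies P, axis=1 applies its transpose.
    return p @ inputs if axis == 0 else p.T @ inputs


p = np.array([[0.2, 0.1, 0.0],
              [0.0, 0.3, 0.4]])  # 2 x 3 plan with total mass 1
row_marginal = apply_plan(p, np.ones(3), axis=0)  # approx. [0.3, 0.7]
col_marginal = apply_plan(p, np.ones(2), axis=1)  # approx. [0.2, 0.4, 0.4]
```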

count(value, /)

Return number of occurrences of value.

index(value[, start, stop])

Return first index of value.

set(**kwargs)

Return a copy of self, possibly with overwrites.
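The `count` and `index` methods and the "Alias for field number" attributes are what a Python `NamedTuple` provides, and `set` is described as returning a modified copy. The sketch below mimics that with a hypothetical `ToyGWOutput` that copies the field order and defaults from the signature above; implementing `set` via `_replace` is an assumption about the behaviour, not OTT's code.

```python
from typing import Any, NamedTuple


class ToyGWOutput(NamedTuple):
    # Hypothetical stand-in with GWOutput's field order; not the OTT class.
    costs: Any = None
    linear_convergence: Any = None
    converged: bool = False
    errors: Any = None
    linear_state: Any = None
    geom: Any = None
    old_transport_mass: float = 1.0

    def set(self, **kwargs) -> "ToyGWOutput":
        # Return a copy with some fields overwritten; self is left unchanged.
        return self._replace(**kwargs)


out = ToyGWOutput()
updated = out.set(converged=True, old_transport_mass=0.5)

print(out.converged, updated.converged)  # False True
print(updated[2] is updated.converged)   # True: field number 2 is `converged`
print(out.count(None), out.index(1.0))   # 5 6
```

Note that `count` and `index` compare fields with `==`; when a field holds a multi-element array, that comparison is elementwise and its truth value is ambiguous, so these tuple methods can raise on a populated output.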

Attributes

converged

Alias for field number 2

costs

Alias for field number 0

errors

Alias for field number 3

geom

Alias for field number 5

linear_convergence

Alias for field number 1

linear_state

Alias for field number 4

matrix

Transport matrix.

n_iters

old_transport_mass

Alias for field number 6

primal_cost

Return the transport cost of the current linear OT solution at the geometry.
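The transport cost of a linear OT solution is usually the inner product of the geometry's cost matrix with the transport matrix, ⟨C, P⟩ = Σ_ij C_ij P_ij. A minimal NumPy sketch, assuming dense arrays and a hypothetical helper `linear_primal_cost`:

```python
import numpy as np


def linear_primal_cost(cost_matrix, p):
    # Frobenius inner product <C, P> = sum_ij C_ij * P_ij.
    return float(np.sum(cost_matrix * p))


c = np.array([[0.0, 1.0],
              [1.0, 0.0]])
independent = np.full((2, 2), 0.25)  # product of uniform marginals
diagonal = np.diag([0.5, 0.5])       # all mass stays in place
print(linear_primal_cost(c, independent))  # 0.5
print(linear_primal_cost(c, diagonal))     # 0.0
```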

reg_gw_cost

Regularized optimal transport cost of the linearization.
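One common definition of the entropy-regularized cost of a linear problem adds ε Σ_ij P_ij (log P_ij − 1) to ⟨C, P⟩, i.e. it subtracts ε times the entropy H(P) = −Σ_ij P_ij (log P_ij − 1). The sketch below (hypothetical helper `entropic_cost`) evaluates that primal objective. OTT's own convention may differ, for instance by using a KL term relative to the product of the marginals or by evaluating the dual objective, so treat it as illustrative only.

```python
import numpy as np


def entropic_cost(cost_matrix, p, eps):
    # <C, P> + eps * sum_ij P_ij (log P_ij - 1), with the convention 0 * log 0 = 0.
    log_p = np.log(p, out=np.zeros_like(p), where=p > 0)
    return float(np.sum(cost_matrix * p) + eps * np.sum(p * (log_p - 1.0)))


c = np.zeros((2, 2))
p = np.full((2, 2), 0.25)
print(entropic_cost(c, p, eps=0.1))  # 0.1 * (log(0.25) - 1), approx. -0.2386
```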