ott.solvers.quadratic.gw_barycenter.GWBarycenterState


class ott.solvers.quadratic.gw_barycenter.GWBarycenterState(cost=None, x=None, a=None, errors=None, costs=None, costs_bary=None, gw_convergence=None)

State of the Gromov-Wasserstein (GW) barycenter problem.

Parameters:
  • cost (Array | None) – Barycenter cost matrix of shape [bar_size, bar_size].

  • x (Array | None) – Barycenter features of shape [bar_size, ndim_fused]. Only used in the fused case.

  • a (Array | None) – Weights of the barycenter of shape [bar_size,].

  • errors (Array | None) – Array of shape [max_iter, num_measures, quad_max_iter, lin_outer_iter] containing the GW errors at each iteration.

  • costs (Array | None) – Array of shape [max_iter,] containing the cost at each iteration.

  • costs_bary (Array | None) – Array of shape [max_iter, num_measures] containing the cost between the individual measures and the barycenter at each iteration.

  • gw_convergence (Array | None) – Array of shape [max_iter,] containing the convergence of all GW problems at each iteration.
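To make the documented shapes concrete, the sketch below mirrors the seven fields in their documented order and allocates arrays with those shapes. It does not use the library. GWBarycenterStateSketch and init_state_sketch are hypothetical names, and the fill values (uniform weights, zeros, -1 placeholders) are assumptions, not necessarily what the solver uses.

```python
from typing import NamedTuple, Optional

import numpy as np


class GWBarycenterStateSketch(NamedTuple):
    """Hypothetical stand-in mirroring the documented field order and shapes."""
    cost: Optional[np.ndarray] = None            # [bar_size, bar_size]
    x: Optional[np.ndarray] = None               # [bar_size, ndim_fused], fused case only
    a: Optional[np.ndarray] = None               # [bar_size,]
    errors: Optional[np.ndarray] = None          # [max_iter, num_measures, quad_max_iter, lin_outer_iter]
    costs: Optional[np.ndarray] = None           # [max_iter,]
    costs_bary: Optional[np.ndarray] = None      # [max_iter, num_measures]
    gw_convergence: Optional[np.ndarray] = None  # [max_iter,]


def init_state_sketch(bar_size, num_measures, max_iter, quad_max_iter,
                      lin_outer_iter, ndim_fused=None):
    """Allocate arrays with the documented shapes (fill values are assumptions)."""
    # Features are only allocated in the fused case.
    x = None if ndim_fused is None else np.zeros((bar_size, ndim_fused))
    return GWBarycenterStateSketch(
        cost=np.zeros((bar_size, bar_size)),
        x=x,
        a=np.full(bar_size, 1.0 / bar_size),  # assumed uniform barycenter weights
        errors=np.full((max_iter, num_measures, quad_max_iter, lin_outer_iter), -1.0),
        costs=np.full(max_iter, -1.0),
        costs_bary=np.full((max_iter, num_measures), -1.0),
        gw_convergence=np.full(max_iter, -1.0),
    )


state = init_state_sketch(bar_size=5, num_measures=3, max_iter=10,
                          quad_max_iter=20, lin_outer_iter=4)
print(state.cost.shape, state.errors.shape, state.x)  # (5, 5) (10, 3, 20, 4) None
```

Passing ndim_fused allocates x as well, matching the note that features are only used in the fused case.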

Methods

count(value, /)

Return number of occurrences of value.

index(value[, start, stop])

Return first index of value.

set(**kwargs)

Return a copy of self, possibly with overwrites.
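count and index are the standard tuple methods, and set returns a modified copy. The sketch below illustrates these semantics on a hypothetical NamedTuple with the same field names, assuming set behaves like namedtuple's _replace. StateSketch is an illustrative name, not part of the library.

```python
from typing import Any, NamedTuple


class StateSketch(NamedTuple):
    """Hypothetical NamedTuple with the same field names as GWBarycenterState."""
    cost: Any = None
    x: Any = None
    a: Any = None
    errors: Any = None
    costs: Any = None
    costs_bary: Any = None
    gw_convergence: Any = None

    def set(self, **kwargs) -> "StateSketch":
        # Assumed semantics: return a copy with the given fields overwritten.
        return self._replace(**kwargs)


s0 = StateSketch()
s1 = s0.set(a=[0.5, 0.5])   # s0 itself is left unchanged
print(s0.count(None))       # 7: every field is still None
print(s1.index([0.5, 0.5])) # 2: the position of field `a`
```

Because count and index compare entries with ==, they are of limited use once fields hold multi-element arrays, where the comparison is elementwise and may raise.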

Attributes

a

Alias for field number 2

cost

Alias for field number 0

costs

Alias for field number 4

costs_bary

Alias for field number 5

errors

Alias for field number 3

gw_convergence

Alias for field number 6

n_iters

Number of iterations.

x

Alias for field number 1
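Each "Alias for field number N" entry means the named attribute and the positional index N refer to the same element. The sketch below shows this on a hypothetical NamedTuple that uses the documented field order. StateSketch and the placeholder string values are illustrative only.

```python
from typing import Any, NamedTuple


class StateSketch(NamedTuple):
    """Hypothetical NamedTuple using the documented field order."""
    cost: Any = None            # field 0
    x: Any = None               # field 1
    a: Any = None               # field 2
    errors: Any = None          # field 3
    costs: Any = None           # field 4
    costs_bary: Any = None      # field 5
    gw_convergence: Any = None  # field 6


state = StateSketch(cost="C", x="X", a="weights")
# Named access and positional access return the same object.
print(state[0] is state.cost, state[2] is state.a)  # True True
print(StateSketch._fields.index("gw_convergence"))  # 6
```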