ott.solvers.quadratic.gw_barycenter.GWBarycenterState

class ott.solvers.quadratic.gw_barycenter.GWBarycenterState(cost=None, x=None, a=None, errors=None, costs=None, gw_convergence=None)

Holds the state of the GWBarycenterProblem.

Parameters
  • cost (Optional[jax.Array]) – Barycenter cost matrix of shape [bar_size, bar_size].

  • x (Optional[jax.Array]) – Barycenter features of shape [bar_size, ndim_fused]. Only used in the fused case.

  • a (Optional[jax.Array]) – Weights of the barycenter of shape [bar_size,].

  • errors (Optional[jax.Array]) – Array of shape [max_iter, num_measures, quad_max_iter, lin_outer_iter] containing the GW errors at each iteration.

  • costs (Optional[jax.Array]) – Array of shape [max_iter,] containing the cost at each iteration.

  • gw_convergence (Optional[jax.Array]) – Array of shape [max_iter,] containing the convergence status of all GW problems at each iteration.
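To make the documented shapes concrete, the sketch below fills a hypothetical stand-in, `BarycenterStateSketch`, that only mirrors the field order of GWBarycenterState; it is not the library class. NumPy arrays stand in for `jax.Array` so the snippet runs without JAX, the sizes are arbitrary illustrative values rather than library defaults, and the boolean dtype for `gw_convergence` is an assumption.

```python
from typing import NamedTuple, Optional

import numpy as np


class BarycenterStateSketch(NamedTuple):
  """Hypothetical stand-in mirroring GWBarycenterState's field order."""
  cost: Optional[np.ndarray] = None
  x: Optional[np.ndarray] = None
  a: Optional[np.ndarray] = None
  errors: Optional[np.ndarray] = None
  costs: Optional[np.ndarray] = None
  gw_convergence: Optional[np.ndarray] = None


# Illustrative sizes (assumptions, not library defaults).
bar_size, ndim_fused = 5, 3
max_iter, num_measures, quad_max_iter, lin_outer_iter = 4, 2, 6, 7

state = BarycenterStateSketch(
    cost=np.zeros((bar_size, bar_size)),     # barycenter cost matrix
    x=np.zeros((bar_size, ndim_fused)),      # only meaningful in the fused case
    a=np.full((bar_size,), 1.0 / bar_size),  # uniform barycenter weights
    errors=np.zeros((max_iter, num_measures, quad_max_iter, lin_outer_iter)),
    costs=np.zeros((max_iter,)),             # one cost per outer iteration
    # Assumed to hold one boolean flag per outer iteration.
    gw_convergence=np.zeros((max_iter,), dtype=bool),
)

# Print each field name with the shape of the array stored in it.
for name in state._fields:
  print(name, getattr(state, name).shape)
```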

Methods

count(value, /)

Return number of occurrences of value.

index(value[, start, stop])

Return first index of value.

set(**kwargs)

Return a copy of self, possibly with overwrites.
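A minimal sketch of how these methods behave, assuming `set` works like `NamedTuple._replace`, which matches its description ("Return a copy of self, possibly with overwrites"). `BarycenterStateSketch` is a hypothetical stand-in, not the library class, and plain tuples replace arrays so the snippet needs only the standard library.

```python
from typing import Any, NamedTuple, Optional


class BarycenterStateSketch(NamedTuple):
  """Hypothetical stand-in mirroring GWBarycenterState's fields and `set`."""
  cost: Optional[Any] = None
  x: Optional[Any] = None
  a: Optional[Any] = None
  errors: Optional[Any] = None
  costs: Optional[Any] = None
  gw_convergence: Optional[Any] = None

  def set(self, **kwargs: Any) -> "BarycenterStateSketch":
    """Return a copy of self, possibly with overwrites."""
    return self._replace(**kwargs)


state = BarycenterStateSketch()
updated = state.set(a=(0.5, 0.5), costs=(1.2, 0.9))

# `set` returns a new tuple; the original state is left unchanged.
assert state.a is None and updated.a == (0.5, 0.5)

# `count` and `index` are inherited from `tuple`; here they locate unset fields.
print(updated.count(None))  # 4
print(updated.index(None))  # 0, since `cost` (field 0) is still unset
```

Because `count` and `index` compare fields with `==`, calling them on a state that holds multi-element arrays can raise an "ambiguous truth value" error; in practice, `set` is the method of interest.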

Attributes

a

Alias for field number 2

cost

Alias for field number 0

costs

Alias for field number 4

errors

Alias for field number 3

gw_convergence

Alias for field number 5

x

Alias for field number 1
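The "Alias for field number N" entries mean that each attribute is a named view of a positional slot in the underlying tuple. The sketch below shows this mapping on a hypothetical stand-in, `BarycenterStateSketch`, that copies only the field order of GWBarycenterState; the string values are placeholders for arrays.

```python
from typing import Any, NamedTuple, Optional


class BarycenterStateSketch(NamedTuple):
  """Hypothetical stand-in with GWBarycenterState's field order."""
  cost: Optional[Any] = None
  x: Optional[Any] = None
  a: Optional[Any] = None
  errors: Optional[Any] = None
  costs: Optional[Any] = None
  gw_convergence: Optional[Any] = None


# Placeholder strings stand in for the real arrays.
state = BarycenterStateSketch(cost="C", x="X", a="A")

# Map each attribute name to its field number.
field_numbers = {name: i for i, name in enumerate(BarycenterStateSketch._fields)}
print(field_numbers)
# {'cost': 0, 'x': 1, 'a': 2, 'errors': 3, 'costs': 4, 'gw_convergence': 5}

# Attribute access and positional indexing reach the same slot.
assert state.a is state[field_numbers["a"]]  # `a` is field number 2

# Tuple unpacking follows the same field order.
cost, x, a, errors, costs, gw_convergence = state
```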