ott.solvers.quadratic.gw_barycenter.GWBarycenterState
class ott.solvers.quadratic.gw_barycenter.GWBarycenterState(cost=None, x=None, a=None, errors=None, costs=None, gw_convergence=None)
Holds the state of the GWBarycenterProblem.

Parameters
- cost (Optional[jax.Array]) – Barycenter cost matrix of shape [bar_size, bar_size].
- x (Optional[jax.Array]) – Barycenter features of shape [bar_size, ndim_fused]. Only used in the fused case.
- a (Optional[jax.Array]) – Weights of the barycenter, of shape [bar_size,].
- errors (Optional[jax.Array]) – Array of shape [max_iter, num_measures, quad_max_iter, lin_outer_iter] containing the GW errors at each iteration.
- costs (Optional[jax.Array]) – Array of shape [max_iter,] containing the cost at each iteration.
- gw_convergence (Optional[jax.Array]) – Array of shape [max_iter,] containing the convergence of all GW problems at each iteration.
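All of the shapes above are determined by a handful of sizes (`bar_size`, `ndim_fused`, `max_iter`, `num_measures`, `quad_max_iter`, `lin_outer_iter`). The sketch below collects them in one place and reports any populated field whose shape disagrees. `expected_shapes` and `shape_mismatches` are hypothetical helpers, not part of OTT; they only assume that every non-`None` field exposes a `.shape` tuple, as `jax.Array` does.

```python
from typing import Any, Dict, Mapping, Optional, Tuple

Shape = Tuple[int, ...]


def expected_shapes(
    bar_size: int,
    max_iter: int,
    num_measures: int,
    quad_max_iter: int,
    lin_outer_iter: int,
    ndim_fused: Optional[int] = None,
) -> Dict[str, Shape]:
    """Hypothetical helper: the documented shape of each state field."""
    shapes = {
        "cost": (bar_size, bar_size),
        "a": (bar_size,),
        "errors": (max_iter, num_measures, quad_max_iter, lin_outer_iter),
        "costs": (max_iter,),
        "gw_convergence": (max_iter,),
    }
    # Barycenter features are only used in the fused case.
    if ndim_fused is not None:
        shapes["x"] = (bar_size, ndim_fused)
    return shapes


def shape_mismatches(
    fields: Mapping[str, Any], shapes: Mapping[str, Shape]
) -> Dict[str, Tuple[Shape, Shape]]:
    """Return {name: (actual, expected)} for every set field with a wrong shape."""
    bad = {}
    for name, expected in shapes.items():
        value = fields.get(name)
        if value is None:  # unset fields default to None and are skipped
            continue
        actual = tuple(value.shape)
        if actual != expected:
            bad[name] = (actual, expected)
    return bad
```

Because the state is a NamedTuple, one plausible way to feed it in is `shape_mismatches(state._asdict(), expected_shapes(...))`, which returns an empty dict when every populated field matches.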
Methods

- count(value, /) – Return number of occurrences of value.
- index(value[, start, stop]) – Return first index of value.
- set(**kwargs) – Return a copy of self, possibly with overwrites.
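The `count` and `index` methods, together with the "Alias for field number" attributes, indicate that the state is a `typing.NamedTuple`: it is immutable, so `set` can only hand back an updated copy. A minimal stand-in, assuming `set` delegates to `NamedTuple._replace`; the class name `StateSketch` is hypothetical and its fields are typed loosely rather than as `jax.Array`.

```python
from typing import Any, NamedTuple, Optional


class StateSketch(NamedTuple):
    """Hypothetical stand-in mirroring GWBarycenterState's fields and defaults."""

    cost: Optional[Any] = None
    x: Optional[Any] = None
    a: Optional[Any] = None
    errors: Optional[Any] = None
    costs: Optional[Any] = None
    gw_convergence: Optional[Any] = None

    def set(self, **kwargs: Any) -> "StateSketch":
        """Return a copy of self, possibly with overwrites."""
        return self._replace(**kwargs)


state = StateSketch(a=[0.5, 0.5])
updated = state.set(costs=[1.0, 0.8])

print(state.costs)                # None: the original is left untouched
print(updated.costs)              # [1.0, 0.8]
print(updated.count(None))        # 4: cost, x, errors, gw_convergence unset
print(updated.index([0.5, 0.5]))  # 2: `a` is field number 2
```

In this sketch, passing a keyword that is not a field (e.g. `state.set(foo=1)`) raises `ValueError` from `_replace` instead of silently adding an attribute.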
Attributes

- a – Alias for field number 2
- cost – Alias for field number 0
- costs – Alias for field number 4
- errors – Alias for field number 3
- gw_convergence – Alias for field number 5
- x – Alias for field number 1
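Each attribute is a positional alias, so `state[i]` and the corresponding named field refer to the same object. The sketch below recovers that index mapping from any namedtuple class through its `_fields` attribute. The helper `field_numbers` is hypothetical, and it is applied here to a `collections.namedtuple` stand-in built from the field order in the signature, since OTT may not be importable in every environment.

```python
from collections import namedtuple
from typing import Dict


def field_numbers(nt_cls) -> Dict[str, int]:
    """Map each field name of a namedtuple class to its positional index."""
    return {name: i for i, name in enumerate(nt_cls._fields)}


# Stand-in with the same field order and None defaults as the signature.
Stand = namedtuple(
    "Stand",
    ["cost", "x", "a", "errors", "costs", "gw_convergence"],
    defaults=[None] * 6,
)

mapping = field_numbers(Stand)
print(mapping["a"], mapping["gw_convergence"])  # 2 5

state = Stand(cost="C", a="weights")
# Indexing by field number and attribute access return the same object.
assert state[mapping["a"]] is state.a
```

With OTT installed, `field_numbers(GWBarycenterState)` should report the same indices as the attribute aliases listed above.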