ott.solvers.quadratic.gw_barycenter.GromovWassersteinBarycenter.init_state
- GromovWassersteinBarycenter.init_state(problem, bar_size, bar_init=None, a=None, rng=None)
Initialize the (fused) Gromov-Wasserstein barycenter state.
- Parameters:
  - problem (GWBarycenterProblem) – The barycenter problem.
  - bar_size (int) – Size of the barycenter.
  - bar_init (Union[Array, Tuple[Array, Array], None]) – Initial barycenter value. Can be one of the following:
    - None - randomly initialize the barycenter.
    - jax.numpy.ndarray - barycenter cost matrix of shape [bar_size, bar_size]. Only used in the non-fused case.
    - tuple of jax.numpy.ndarray - the first array is a cost matrix of shape [bar_size, bar_size]; the second array is a [bar_size, ndim_fused] feature matrix used in the fused case.
  - a (Optional[Array]) – An array of shape [bar_size,] containing the barycenter weights.
  - rng (Optional[Array]) – Random key used for seeding when bar_init = None.
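The shapes documented for bar_init and a can be illustrated with a small JAX sketch. It is a minimal example under stated assumptions, not part of ott: the helpers sq_euclidean_cost, make_bar_init and check_bar_init are hypothetical, and building the initial cost from squared Euclidean distances between random points, with uniform weights, is an illustrative choice rather than what the solver does internally when bar_init = None.

```python
import jax
import jax.numpy as jnp


def sq_euclidean_cost(x: jnp.ndarray) -> jnp.ndarray:
    """Pairwise squared Euclidean distances between rows of x, shape [n, n]."""
    sq_norms = jnp.sum(x**2, axis=-1)
    cost = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T
    # Clip tiny negative values caused by floating-point cancellation.
    return jnp.maximum(cost, 0.0)


def make_bar_init(rng, bar_size, ndim=2, ndim_fused=None):
    """Hypothetical helper: build bar_init in one of the documented forms.

    Returns a [bar_size, bar_size] cost matrix (non-fused case), or a
    (cost, features) tuple with features of shape [bar_size, ndim_fused]
    (fused case).
    """
    rng_x, rng_f = jax.random.split(rng)
    points = jax.random.normal(rng_x, (bar_size, ndim))  # assumed point cloud
    cost = sq_euclidean_cost(points)
    if ndim_fused is None:
        return cost
    features = jax.random.normal(rng_f, (bar_size, ndim_fused))
    return cost, features


def check_bar_init(bar_init, bar_size, is_fused):
    """Hypothetical validator mirroring the documented shapes of bar_init."""
    if bar_init is None:
        return  # random initialization; rng is used instead
    if is_fused:
        cost, features = bar_init
        if features.ndim != 2 or features.shape[0] != bar_size:
            raise ValueError("features must have shape [bar_size, ndim_fused]")
    else:
        cost = bar_init
    if cost.shape != (bar_size, bar_size):
        raise ValueError("cost must have shape [bar_size, bar_size]")


bar_size = 5
rng = jax.random.PRNGKey(0)

# Uniform barycenter weights a, shape [bar_size,] (an assumption; any
# probability vector of this shape fits the documented parameter).
a = jnp.full((bar_size,), 1.0 / bar_size)

cost_init = make_bar_init(rng, bar_size)                  # non-fused form
fused_init = make_bar_init(rng, bar_size, ndim_fused=3)   # fused form
check_bar_init(cost_init, bar_size, is_fused=False)
check_bar_init(fused_init, bar_size, is_fused=True)

# Assuming `solver` is a GromovWassersteinBarycenter and `problem` a
# GWBarycenterProblem built elsewhere, the documented call would look like:
# state = solver.init_state(problem, bar_size, bar_init=fused_init, a=a, rng=rng)
```

Passing a plain cost matrix only makes sense for a non-fused problem; for a fused problem the tuple form supplies both the cost matrix and the feature matrix.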
- Return type:
  GWBarycenterState
- Returns:
The initial barycenter state.