ott.core.gw_barycenter.GromovWassersteinBarycenter.init_state

GromovWassersteinBarycenter.init_state(problem, bar_size, bar_init=None, a=None, seed=0)

Initialize the (fused) Gromov-Wasserstein barycenter state.

Parameters
  • problem (GWBarycenterProblem) – The barycenter problem.

  • bar_size (int) – Size of the barycenter.

  • bar_init (Union[ndarray, Tuple[ndarray, ndarray], None]) –

    Initial barycenter value. Can be one of the following:

    • None - randomly initialize the barycenter.

    • jax.numpy.ndarray - barycenter cost matrix of shape [bar_size, bar_size]. Only used in the non-fused case.

    • 2-tuple of jax.numpy.ndarray - the 1st array is the [bar_size, bar_size] cost matrix; the 2nd array is a [bar_size, ndim_fused] feature matrix used in the fused case.

  • a (Optional[ndarray]) – An array of shape [bar_size,] containing the barycenter weights.

  • seed (int) – Random seed used when bar_init is None.

Return type

GWBarycenterState

Returns

The initial barycenter state.
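
Example

The following is a minimal sketch of how arrays with the shapes expected by bar_init and a might be prepared. The helper make_bar_init is hypothetical and not part of the library; building the cost matrix from random 2-D points and using uniform weights are assumptions made for illustration, not a description of how init_state initializes the barycenter internally.

```python
import jax
import jax.numpy as jnp


def make_bar_init(bar_size, ndim_fused=None, seed=0):
    """Hypothetical helper: build arrays shaped like `bar_init` and `a`."""
    key_pts, key_feat = jax.random.split(jax.random.PRNGKey(seed))

    # Assumption: sample random 2-D points and use their pairwise squared
    # Euclidean distances as a [bar_size, bar_size] cost matrix.
    pts = jax.random.normal(key_pts, (bar_size, 2))
    diff = pts[:, None, :] - pts[None, :, :]
    cost = jnp.sum(diff ** 2, axis=-1)

    # Assumption: uniform barycenter weights of shape [bar_size,].
    a = jnp.ones((bar_size,)) / bar_size

    if ndim_fused is None:
        # Non-fused case: a single cost matrix.
        return cost, a

    # Fused case: (cost matrix, feature matrix) 2-tuple.
    feats = jax.random.normal(key_feat, (bar_size, ndim_fused))
    return (cost, feats), a


# Non-fused: one [4, 4] array. Fused: a ([4, 4], [4, 3]) tuple.
cost_only, weights = make_bar_init(bar_size=4)
(cost, feats), _ = make_bar_init(bar_size=4, ndim_fused=3)
```

Values built this way could then be passed as the bar_init and a arguments, with bar_size matching the leading dimension of each array.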