ott.solvers.quadratic.gw_barycenter.GromovWassersteinBarycenter.init_state


GromovWassersteinBarycenter.init_state(problem, bar_size, bar_init=None, a=None, rng=None)

Initialize the (fused) Gromov-Wasserstein barycenter state.

Parameters:
  • problem (GWBarycenterProblem) – The barycenter problem.

  • bar_size (int) – Size of the barycenter.

  • bar_init (Union[Array, Tuple[Array, Array], None]) –

    Initial barycenter value. Can be one of the following:

    • None - randomly initialize the barycenter.

    • jax.numpy.ndarray - barycenter cost matrix of shape [bar_size, bar_size]. Only used in the non-fused case.

    • tuple of jax.numpy.ndarray - used in the fused case. The first array is the barycenter cost matrix of shape [bar_size, bar_size]; the second is the barycenter feature matrix of shape [bar_size, ndim_fused].

  • a (Optional[Array]) – An array of shape [bar_size,] containing the barycenter weights.

  • rng (Optional[Array]) – Random key used for seeding when bar_init is None.
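The accepted bar_init forms and the shape of a can be illustrated with a small sketch. It uses NumPy only and does not call OTT. The helpers make_bar_init and uniform_weights are hypothetical. Two choices are assumptions, not library behavior: building the initial cost matrix from squared Euclidean distances between random points, and using uniform weights that sum to one. The seed argument is a NumPy seed and is unrelated to the JAX rng key above. The arrays can be converted with jax.numpy.asarray if JAX arrays are required.

```python
import numpy as np


def sq_euclidean_cost(x: np.ndarray) -> np.ndarray:
    """Pairwise squared Euclidean distances between the rows of x."""
    sq_norms = np.sum(x**2, axis=1)
    cost = sq_norms[:, None] + sq_norms[None, :] - 2.0 * x @ x.T
    # Clip tiny negative values caused by floating-point round-off.
    return np.maximum(cost, 0.0)


def make_bar_init(bar_size, ndim_points=2, ndim_fused=None, seed=0):
    """Hypothetical helper producing a value in one of the documented bar_init forms.

    Non-fused case: a [bar_size, bar_size] cost matrix.
    Fused case: a tuple (cost matrix, [bar_size, ndim_fused] feature matrix).
    The cost is built from random points (an illustrative assumption).
    """
    rng = np.random.default_rng(seed)
    points = rng.normal(size=(bar_size, ndim_points))
    cost = sq_euclidean_cost(points)
    if ndim_fused is None:
        return cost
    features = rng.normal(size=(bar_size, ndim_fused))
    return cost, features


def uniform_weights(bar_size):
    """Hypothetical helper: uniform barycenter weights of shape [bar_size,] summing to 1."""
    return np.full((bar_size,), 1.0 / bar_size)


bar_size = 5
cost = make_bar_init(bar_size)                      # non-fused bar_init
fused_init = make_bar_init(bar_size, ndim_fused=3)  # fused bar_init: (cost, features)
a = uniform_weights(bar_size)                       # candidate value for a
```

Any symmetric [bar_size, bar_size] matrix with the right shape fits the non-fused form. Squared Euclidean distances are used here only because they give a symmetric matrix with a zero diagonal.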

Return type:

GWBarycenterState

Returns:

The initial barycenter state.