```# Copyright OTT-JAX
#
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#
# Unless required by applicable law or agreed to in writing, software
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
from functools import partial
from typing import Any, Dict, NamedTuple, Optional, Sequence, Tuple, Union

import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import pointcloud
from ott.math import fixed_point_loop
from ott.problems.linear import linear_problem
from ott.solvers import was_solver

__all__ = ["GWBarycenterState", "GromovWassersteinBarycenter"]

[docs]
class GWBarycenterState(NamedTuple):
"""State of the GW barycenter problem.

Args:
cost: Barycenter cost matrix of shape ``[bar_size, bar_size]``.
x: Barycenter features of shape ``[bar_size, ndim_fused]``.
Only used in the fused case.
a: Weights of the barycenter of shape ``[bar_size,]``.
errors: Array of shape
the GW errors at each iteration.
costs: Array of shape ``[max_iter,]`` containing the cost at each iteration.
costs_bary: Array of shape ``[max_iter, num_measures]`` containing the
cost between the individual measures and the barycenter at each iteration.
gw_convergence: Array of shape ``[max_iter,]`` containing the convergence
of all GW problems at each iteration.
"""
cost: Optional[jnp.ndarray] = None
x: Optional[jnp.ndarray] = None
a: Optional[jnp.ndarray] = None
errors: Optional[jnp.ndarray] = None
costs: Optional[jnp.ndarray] = None
costs_bary: Optional[jnp.ndarray] = None
gw_convergence: Optional[jnp.ndarray] = None

[docs]
def set(self, **kwargs: Any) -> "GWBarycenterState":
"""Return a copy of self, possibly with overwrites."""
return self._replace(**kwargs)

@property
def n_iters(self) -> int:
"""Number of iterations."""
if self.gw_convergence is None:
return -1
return jnp.sum(self.gw_convergence != -1)

[docs]
@jax.tree_util.register_pytree_node_class
class GromovWassersteinBarycenter(was_solver.WassersteinSolver):
"""Gromov-Wasserstein barycenter solver.

Args:
epsilon: Entropy regularizer.
min_iterations: Minimum number of iterations.
max_iterations: Maximum number of outermost iterations.
threshold: Convergence threshold.
store_inner_errors: Whether to store the errors of the GW solver, as well
as its linear solver, at each iteration for each measure.
kwargs: Keyword argument for
Only used when ``quad_solver = None``.
"""

def __init__(
self,
epsilon: Optional[float] = None,
min_iterations: int = 5,
max_iterations: int = 50,
threshold: float = 1e-3,
store_inner_errors: bool = False,
# TODO(michalk8): maintain the API compatibility with `was_solver`
# but makes passing kwargs with the same name to `quad_solver` impossible
# will be fixed when refactoring the solvers
# note that `was_solver` also suffers from this
**kwargs: Any,
):
super().__init__(
epsilon=epsilon,
min_iterations=min_iterations,
max_iterations=max_iterations,
threshold=threshold,
store_inner_errors=store_inner_errors,
)
kwargs["epsilon"] = epsilon
# TODO(michalk8): store only GW errors?
kwargs["store_inner_errors"] = store_inner_errors
else:

def __call__(
self, problem: gw_barycenter.GWBarycenterProblem, bar_size: int,
**kwargs: Any
) -> GWBarycenterState:
"""Solver the (fused) GW barycenter problem.

Args:
problem: The GW barycenter problem.
bar_size: Size of the barycenter.
kwargs: Keyword arguments for :meth:`init_state`.

Returns:
The solution.
"""
state = self.init_state(problem, bar_size, **kwargs)
state = iterations(self, problem, state)
return self.output_from_state(state)

[docs]
def init_state(
self,
problem: gw_barycenter.GWBarycenterProblem,
bar_size: int,
bar_init: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray,
jnp.ndarray]]] = None,
a: Optional[jnp.ndarray] = None,
rng: Optional[jax.Array] = None,
) -> GWBarycenterState:
"""Initialize the (fused) Gromov-Wasserstein barycenter state.

Args:
problem: The barycenter problem.
bar_size: Size of the barycenter.
bar_init: Initial barycenter value. Can be one of the following:

- ``None`` - randomly initialize the barycenter.
- :class:`jax.numpy.ndarray` - barycenter cost matrix of shape
``[bar_size, bar_size]``.
Only used in the non-fused case.
- :class:`tuple` of :class:`jax.numpy.ndarray` - the first array
corresponds to a cost matrix of shape ``[bar_size, bar_size]``,
the second array is a ``[bar_size, ndim_fused]`` feature matrix used
in the fused case.

a: An array of shape ``[bar_size,]`` containing the barycenter weights.
rng: Random key for seeding used when ``bar_init = None``.

Returns:
The initial barycenter state.
"""
if a is None:
a = jnp.ones((bar_size,)) / bar_size
else:
assert a.shape == (bar_size,)

if bar_init is None:
rng = utils.default_prng_key(rng)
_, b = problem.segmented_y_b
rngs = jax.random.split(rng, problem.num_measures)

transports = init_transports(linear_solver, rngs, a, b, problem.epsilon)
x = problem.update_features(transports, a) if problem.is_fused else None
cost = problem.update_barycenter(transports, a)
else:
cost, x = bar_init if isinstance(bar_init, tuple) else (bar_init, None)
assert cost.shape == (bar_size, bar_size)
if problem.is_fused:
assert x is not None, "Barycenter features are not initialized."
assert x.shape == (bar_size, problem.ndim_fused)

num_iter = self.max_iterations
if self.store_inner_errors:
# TODO(michalk8): in the future, think about how to do this in general
errors = -jnp.ones((
))
else:
errors = None

costs = -jnp.ones((num_iter,))
costs_bary = -jnp.ones((num_iter, problem.num_measures))
gw_convergence = -jnp.ones((num_iter,))
return GWBarycenterState(
cost=cost,
x=x,
a=a,
errors=errors,
costs=costs,
costs_bary=costs_bary,
gw_convergence=gw_convergence
)

[docs]
def update_state(
self,
state: GWBarycenterState,
iteration: int,
problem: gw_barycenter.GWBarycenterProblem,
store_errors: bool = True,
) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]:
"""Solve the (fused) Gromov-Wasserstein barycenter problem."""

def solve_gw(
state: GWBarycenterState, b: jnp.ndarray, y: jnp.ndarray,
f: Optional[jnp.ndarray]
) -> Tuple[float, bool, jnp.ndarray, Optional[jnp.ndarray]]:
quad_problem = problem._create_problem(state, y=y, b=b, f=f)
return (
out.reg_gw_cost, out.converged, out.matrix,
out.errors if store_errors else None
)

in_axes = [None, 0, 0]
in_axes += [0] if problem.is_fused else [None]
solve_fn = jax.vmap(solve_gw, in_axes=in_axes)

y, b = problem.segmented_y_b
y_f = problem.segmented_y_fused
costs, convergeds, transports, errors = solve_fn(state, b, y, y_f)

cost = jnp.sum(costs * problem.weights)
costs_bary = state.costs_bary.at[iteration].set(costs)
costs = state.costs.at[iteration].set(cost)

converged = jnp.all(convergeds)
gw_convergence = state.gw_convergence.at[iteration].set(converged)

if self.store_inner_errors:
errors = state.errors.at[iteration, ...].set(errors)
else:
errors = None

x = problem.update_features(
transports, state.a
) if problem.is_fused else state.x
cost = problem.update_barycenter(transports, state.a)
return state.set(
cost=cost,
x=x,
costs=costs,
costs_bary=costs_bary,
errors=errors,
gw_convergence=gw_convergence
)

[docs]
def output_from_state(self, state: GWBarycenterState) -> GWBarycenterState:
"""No-op."""
# TODO(michalk8): just for consistency with continuous barycenter
# will be refactored in the future to create an output
return state

def tree_flatten(self) -> Tuple[Sequence[Any], Dict[str, Any]]:  # noqa: D102
children, aux = super().tree_flatten()

@classmethod
def tree_unflatten(  # noqa: D102
cls, aux_data: Dict[str, Any], children: Sequence[Any]
) -> "GromovWassersteinBarycenter":
epsilon, _, threshold, quad_solver = children
return cls(
epsilon=epsilon,
threshold=threshold,
**aux_data,
)

@partial(jax.vmap, in_axes=[None, 0, None, 0, None])
def init_transports(
solver, rng: jax.Array, a: jnp.ndarray, b: jnp.ndarray,
epsilon: Optional[float]
) -> jnp.ndarray:
"""Initialize random 2D point cloud and solve the linear OT problem.

Args:
solver: Linear OT solver.
rng: Random key for seeding.
a: Source marginals (e.g., for barycenter) of shape ``[bar_size,]``.
b: Target marginals of shape ``[max_measure_size,]``.
epsilon: Entropy regularization.

Returns:
Transport map of shape ``[bar_size, max_measure_size]``.
"""
rng1, rng2 = jax.random.split(rng, 2)
x = jax.random.normal(rng1, shape=(len(a), 2))
y = jax.random.normal(rng2, shape=(len(b), 2))
geom = pointcloud.PointCloud(
)
problem = linear_problem.LinearProblem(geom, a=a, b=b)
return solver(problem).matrix

def iterations(  # noqa: D103
solver: GromovWassersteinBarycenter,
problem: gw_barycenter.GWBarycenterProblem, init_state: GWBarycenterState
) -> GWBarycenterState:

def cond_fn(
iteration: int, constants: GromovWassersteinBarycenter,
state: GWBarycenterState
) -> bool:
solver, _ = constants
return solver._continue(state, iteration)

def body_fn(
iteration, constants: Tuple[GromovWassersteinBarycenter,
gw_barycenter.GWBarycenterProblem],
state: GWBarycenterState, compute_error: bool
) -> GWBarycenterState:
del compute_error  # always assumed true
solver, problem = constants
return solver.update_state(state, iteration, problem)

return fixed_point_loop.fixpoint_iter(
cond_fn=cond_fn,
body_fn=body_fn,
min_iterations=solver.min_iterations,
max_iterations=solver.max_iterations,
inner_iterations=1,
constants=(solver, problem),
state=init_state,
)
```