Source code for ott.solvers.linear.continuous_barycenter

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, NamedTuple, Optional, Tuple, Union

import jax
import jax.numpy as jnp

from ott import utils
from ott.geometry import pointcloud
from ott.math import fixed_point_loop
from ott.math import utils as mu
from ott.problems.linear import barycenter_problem, linear_problem
from ott.solvers import was_solver
from ott.solvers.linear import sinkhorn, sinkhorn_lr

__all__ = ["FreeBarycenterState", "FreeWassersteinBarycenter"]

LinearOutput = Union[sinkhorn.SinkhornOutput, sinkhorn_lr.LRSinkhornOutput]


[docs] class FreeBarycenterState(NamedTuple): """Holds the state of the Wasserstein barycenter solver. Args: x: Barycenter points. a: Barycenter weights. costs: Holds the sequence of regularized costs seen through the outer loop of the solver. linear_convergence: Holds the sequence of bool convergence flags of the inner Sinkhorn iterations. linear_outputs: Holds latest output objects returned by linear solver when recomputing transport between barycenter vs. each of the input measures. errors: Holds sequence of vectors of errors of the Sinkhorn algorithm at each iteration. """ x: jnp.ndarray a: jnp.ndarray costs: Optional[jnp.ndarray] = None linear_convergence: Optional[jnp.ndarray] = None linear_outputs: Optional[LinearOutput] = None errors: Optional[jnp.ndarray] = None
[docs] def set(self, **kwargs: Any) -> "FreeBarycenterState": """Return a copy of self, possibly with overwrites.""" return self._replace(**kwargs)
[docs] def update( self, iteration: int, bar_prob: barycenter_problem.FreeBarycenterProblem, linear_solver: Any, store_errors: bool ) -> "FreeBarycenterState": """Update the state of the solver. Args: iteration: the current iteration of the outer loop. bar_prob: the barycenter problem. linear_solver: the linear OT solver to use. store_errors: whether to store the errors of the inner loop. Returns: The updated state. """ seg_y, seg_b, _ = bar_prob.segmented_y_b @functools.partial(jax.vmap, in_axes=[None, None, 0, 0]) def solve_linear_ot( a: Optional[jnp.ndarray], x: jnp.ndarray, b: jnp.ndarray, y: jnp.ndarray ): geom = pointcloud.PointCloud( x, y, cost_fn=bar_prob.cost_fn, epsilon=bar_prob.epsilon ) prob = linear_problem.LinearProblem(geom, a=a, b=b) out = linear_solver(prob) # instantiate matrix since it is a property of out. return out, out.matrix outs, matrices = solve_linear_ot(self.a, self.x, seg_b, seg_y) reg_ot_costs = outs.reg_ot_cost convergeds = outs.converged errors = outs.errors cost = jnp.sum(reg_ot_costs * bar_prob.weights) updated_costs = self.costs.at[iteration].set(cost) converged = jnp.all(convergeds) linear_convergence = self.linear_convergence.at[iteration].set(converged) if store_errors and self.errors is not None: errors = self.errors.at[iteration, :, :].set(errors) else: errors = None # Approximation of barycenter as barycenter of barycenters per measure. barycenters_per_measure = mu.barycentric_projection( matrices, seg_y, bar_prob.cost_fn ) x_new = jax.vmap( lambda w, y: bar_prob.cost_fn.barycenter(w, y)[0], in_axes=[None, 1] )(bar_prob.weights, barycenters_per_measure) return self.set( x=x_new, costs=updated_costs, linear_convergence=linear_convergence, linear_outputs=outs, errors=errors )
class FreeBarycenterOutput(NamedTuple): """Holds the output of a Free Wasserstein barycenter solver. Objects of this class contain both solutions and problem definition of a regularized Free Barycenter OT problem, along several methods that can be used to access its content. Args: x: barycenter points. a: barycenter weights. costs: Holds the sequence of regularized GW costs seen through the outer loop of the solver. linear_convergence: Holds the sequence of bool convergence flags of the inner Sinkhorn iterations. linear_outputs: Holds latest output objects returned by linear solver when recomputing transport between barycenter vs. each of the input measures. errors: Holds sequence of vectors of errors of the Sinkhorn algorithm at each iteration. """ x: jnp.ndarray a: jnp.ndarray bar_prob: barycenter_problem.FreeBarycenterProblem costs: jnp.ndarray linear_convergence: jnp.ndarray linear_outputs: LinearOutput errors: Optional[jnp.ndarray] = None @property def all_linear_solvers_converged(self) -> bool: """Whether all linear convergence flags converged.""" return jnp.all(self.linear_convergence[self.linear_convergence != -1]) def matrix_at_index(self, measure_index: int) -> jnp.ndarray: """Return the transport matrix from barycenter to measure_index measure.""" size_measure = self.bar_prob.num_per_measure[measure_index] matrix = self.linear_output_at_index(measure_index).matrix return jax.lax.dynamic_slice_in_dim(matrix, 0, size_measure, axis=-1) @property def bar_size(self) -> int: """Size of the barycenter.""" return self.x.shape[0] @property def num_iters(self) -> int: """Number of outer iterations performed to converge.""" return jnp.sum(self.linear_convergence != -1) @property def costs_along_iterations(self) -> jnp.ndarray: """Costs vector with superfluous values removed.""" return self.costs[:self.num_iters] def linear_output_at_index(self, i: int) -> LinearOutput: """Linear solver output of transport from barycenter to measure i.""" return jax.tree.map(lambda x: x[i], self.linear_outputs)
[docs] @jax.tree_util.register_pytree_node_class class FreeWassersteinBarycenter(was_solver.WassersteinSolver): """Continuous Wasserstein barycenter solver :cite:`cuturi:14`.""" def __call__( # noqa: D102 self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int = 100, x_init: Optional[jnp.ndarray] = None, rng: Optional[jax.Array] = None, ) -> FreeBarycenterState: rng = utils.default_prng_key(rng) return self.iterations(bar_size, bar_prob, x_init, rng)
[docs] def init_state( self, bar_prob: barycenter_problem.FreeBarycenterProblem, bar_size: int, x_init: Optional[jnp.ndarray] = None, rng: Optional[jax.Array] = None, ) -> FreeBarycenterState: """Initialize the state of the Wasserstein barycenter iterations. Args: bar_prob: The barycenter problem. bar_size: Size of the barycenter. x_init: Initial barycenter estimate of shape ``[bar_size, ndim]``. If `None`, ``bar_size`` points will be sampled from the input measures according to their weights :attr:`~ott.problems.linear.barycenter_problem.FreeBarycenterProblem.flattened_y`. rng: Random key for seeding. Returns: The initial barycenter state. """ if x_init is not None: assert bar_size == x_init.shape[0] x = x_init else: # sample randomly points in the support of the y measures rng = utils.default_prng_key(rng) indices_subset = jax.random.choice( rng, a=bar_prob.flattened_y.shape[0], shape=(bar_size,), replace=False, p=bar_prob.flattened_b ) x = bar_prob.flattened_y[indices_subset, :] # TODO(cuturi) expand to non-uniform weights for barycenter. a = jnp.ones((bar_size,)) / bar_size num_iter = self.max_iterations if self.store_inner_errors: errors = -jnp.ones(( num_iter, bar_prob.num_measures, self.linear_solver.outer_iterations )) else: errors = None state = FreeBarycenterState( x=x, a=a, costs=-jnp.ones((num_iter,)), linear_convergence=-jnp.ones((num_iter,)), errors=errors ) abstract_tree = jax.eval_shape( functools.partial(state.update, store_errors=self.store_inner_errors), 0, bar_prob, self.linear_solver ) linear_outputs = jax.tree.map(jnp.zeros_like, abstract_tree.linear_outputs) return FreeBarycenterState( x=x, a=a, costs=-jnp.ones((num_iter,)), linear_convergence=-jnp.ones((num_iter,)), linear_outputs=linear_outputs, errors=errors )
[docs] def output_from_state( # noqa: D102 self, state: FreeBarycenterState, bar_prob: barycenter_problem.FreeBarycenterProblem ) -> FreeBarycenterOutput: """Create an output from a barycenter state.""" return FreeBarycenterOutput( x=state.x, a=state.a, bar_prob=bar_prob, costs=state.costs, linear_convergence=state.linear_convergence, linear_outputs=state.linear_outputs, errors=state.errors, )
[docs] def iterations( self, bar_size: int, bar_prob: barycenter_problem.FreeBarycenterProblem, x_init: jnp.ndarray, rng: jax.Array ) -> FreeBarycenterState: """Wasserstein barycenter outer loop.""" def cond_fn( iteration: int, constants: Tuple[FreeWassersteinBarycenter, barycenter_problem.FreeBarycenterProblem], state: FreeBarycenterState ) -> bool: return self._continue(state, iteration) def body_fn( iteration, constants: Tuple[FreeWassersteinBarycenter, barycenter_problem.FreeBarycenterProblem], state: FreeBarycenterState, compute_error: bool ) -> FreeBarycenterState: del compute_error # Always assumed True bar_prob = constants return state.update( iteration, bar_prob, self.linear_solver, self.store_inner_errors ) state = fixed_point_loop.fixpoint_iter( cond_fn=cond_fn, body_fn=body_fn, min_iterations=self.min_iterations, max_iterations=self.max_iterations, inner_iterations=1, constants=bar_prob, state=self.init_state(bar_prob, bar_size, x_init, rng) ) return self.output_from_state(state, bar_prob)