# Source code for ott.solvers.linear.continuous_barycenter

# Copyright OTT-JAX
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
from typing import Any, NamedTuple, Optional, Tuple

import jax
import jax.numpy as jnp

from ott.geometry import pointcloud
from ott.math import fixed_point_loop
from ott.math import utils as mu
from ott.problems.linear import barycenter_problem, linear_problem
from ott.solvers import was_solver

__all__ = ["FreeBarycenterState", "FreeWassersteinBarycenter"]

class FreeBarycenterState(NamedTuple):
    """Holds the state of the Wasserstein barycenter solver.

    Args:
      costs: Sequence of weighted regularized OT costs, one entry per
        iteration of the outer loop of the solver.
      linear_convergence: Sequence of convergence flags of the inner linear
        OT solves, one entry per iteration of the outer loop.
      errors: Sequence of error arrays returned by the inner linear OT
        solver at each outer iteration; ``None`` when errors are not stored.
      x: barycenter points.
      a: barycenter weights.
    """

    costs: Optional[jnp.ndarray] = None
    linear_convergence: Optional[jnp.ndarray] = None
    errors: Optional[jnp.ndarray] = None
    x: Optional[jnp.ndarray] = None
    a: Optional[jnp.ndarray] = None

    def set(self, **kwargs: Any) -> "FreeBarycenterState":
        """Return a copy of self, possibly with overwrites."""
        return self._replace(**kwargs)

    def update(
        self,
        iteration: int,
        bar_prob: "barycenter_problem.FreeBarycenterProblem",
        linear_ot_solver: Any,
        store_errors: bool,
    ) -> "FreeBarycenterState":
        """Update the state of the solver.

        Solves one linear OT problem per input measure between the current
        barycenter ``(self.x, self.a)`` and that measure, records the weighted
        cost and convergence flag at position ``iteration``, and recomputes
        the barycenter points from the resulting transport matrices.

        Args:
          iteration: the current iteration of the outer loop.
          bar_prob: the barycenter problem.
          linear_ot_solver: the linear OT solver to use.
          store_errors: whether to store the errors of the inner loop.

        Returns:
          The updated state.

        Raises:
          NotImplementedError: if ``bar_prob.debiased`` is set.
        """
        # Fail fast: nothing below supports the debiased variant.
        if bar_prob.debiased:
            raise NotImplementedError(
                "Debiased version of continuous Wasserstein barycenter "
                "not yet implemented."
            )

        seg_y, seg_b = bar_prob.segmented_y_b

        # The barycenter (a, x) is shared by every solve (in_axes=None) while
        # the measure weights/points (b, y) are mapped over their leading axis.
        @functools.partial(jax.vmap, in_axes=[None, None, 0, 0])
        def solve_linear_ot(
            a: Optional[jnp.ndarray],
            x: jnp.ndarray,
            b: jnp.ndarray,
            y: jnp.ndarray,
        ):
            geom = pointcloud.PointCloud(
                x,
                y,
                src_mask=a > 0.,
                tgt_mask=b > 0.,
                cost_fn=bar_prob.cost_fn,
                epsilon=bar_prob.epsilon,
            )
            out = linear_ot_solver(linear_problem.LinearProblem(geom, a, b))
            return (
                out.reg_ot_cost,
                out.converged,
                out.matrix,
                out.errors if store_errors else None,
            )

        reg_ot_costs, convergeds, matrices, errors = solve_linear_ot(
            self.a, self.x, seg_b, seg_y
        )

        # Functional (out-of-place) writes into the pre-allocated buffers.
        cost = jnp.sum(reg_ot_costs * bar_prob.weights)
        updated_costs = self.costs.at[iteration].set(cost)
        converged = jnp.all(convergeds)
        linear_convergence = self.linear_convergence.at[iteration].set(
            converged
        )

        if store_errors and self.errors is not None:
            errors = self.errors.at[iteration, :, :].set(errors)
        else:
            errors = None

        # Approximation of barycenter as barycenter of barycenters per measure.
        barycenters_per_measure = mu.barycentric_projection(
            matrices, seg_y, bar_prob.cost_fn
        )
        x_new = jax.vmap(
            bar_prob.cost_fn.barycenter, in_axes=[None, 1]
        )(bar_prob.weights, barycenters_per_measure)

        return self.set(
            costs=updated_costs,
            linear_convergence=linear_convergence,
            errors=errors,
            x=x_new,
        )
@jax.tree_util.register_pytree_node_class
class FreeWassersteinBarycenter(was_solver.WassersteinSolver):
    """Continuous Wasserstein barycenter solver :cite:`cuturi:14`."""

    def __call__(
        self,
        bar_prob: barycenter_problem.FreeBarycenterProblem,
        bar_size: int = 100,
        x_init: Optional[jnp.ndarray] = None,
        rng: Optional[jnp.ndarray] = None,
    ) -> FreeBarycenterState:
        """Run the barycenter iterations on ``bar_prob``.

        Args:
          bar_prob: The barycenter problem.
          bar_size: Size of the barycenter.
          x_init: Initial barycenter estimate; forwarded to
            :meth:`init_state`.
          rng: Random key for seeding. If `None`,
            ``jax.random.PRNGKey(0)`` is used.

        Returns:
          The final barycenter state.
        """
        # Build the default key lazily; a key in the signature would be
        # created once, at import time.
        if rng is None:
            rng = jax.random.PRNGKey(0)
        # TODO(michalk8): no reason for iterations to be outside this class
        return iterations(self, bar_size, bar_prob, x_init, rng)

    def init_state(
        self,
        bar_prob: barycenter_problem.FreeBarycenterProblem,
        bar_size: int,
        x_init: Optional[jnp.ndarray] = None,
        rng: Optional[jnp.ndarray] = None,
    ) -> FreeBarycenterState:
        """Initialize the state of the Wasserstein barycenter iterations.

        Args:
          bar_prob: The barycenter problem.
          bar_size: Size of the barycenter.
          x_init: Initial barycenter estimate of shape ``[bar_size, ndim]``.
            If `None`, ``bar_size`` points are sampled without replacement
            from
            :attr:`~ott.problems.linear.barycenter_problem.FreeBarycenterProblem.flattened_y`
            with probabilities given by ``flattened_b``.
          rng: Random key for seeding. If `None`,
            ``jax.random.PRNGKey(0)`` is used.

        Returns:
          The initial barycenter state.
        """
        if x_init is not None:
            assert bar_size == x_init.shape[0], (
                "bar_size must equal x_init.shape[0]"
            )
            x = x_init
        else:
            if rng is None:
                rng = jax.random.PRNGKey(0)
            # sample randomly points in the support of the y measures
            indices_subset = jax.random.choice(
                rng,
                a=bar_prob.flattened_y.shape[0],
                shape=(bar_size,),
                replace=False,
                p=bar_prob.flattened_b,
            )
            x = bar_prob.flattened_y[indices_subset, :]

        # TODO(cuturi) expand to non-uniform weights for barycenter.
        a = jnp.ones((bar_size,)) / bar_size

        num_iter = self.max_iterations
        if self.store_inner_errors:
            errors = -jnp.ones((
                num_iter,
                bar_prob.num_measures,
                self.linear_ot_solver.outer_iterations,
            ))
        else:
            errors = None

        # Per-iteration buffers are pre-filled with -1 (presumably a
        # "not yet computed" marker -- confirm against the loop's stopping
        # rule).
        return FreeBarycenterState(
            -jnp.ones((num_iter,)), -jnp.ones((num_iter,)), errors, x, a
        )

    def output_from_state(
        self, state: FreeBarycenterState
    ) -> FreeBarycenterState:
        """Return ``state`` unchanged as the solver output."""
        # TODO(michalk8): create an output variable to match rest of the framework
        return state
def iterations(
    solver: FreeWassersteinBarycenter,
    bar_size: int,
    bar_prob: barycenter_problem.FreeBarycenterProblem,
    x_init: jnp.ndarray,
    rng: jax.random.PRNGKeyArray
) -> FreeBarycenterState:
    """Jittable outer loop of the Wasserstein barycenter solver.

    Builds the initial state with ``solver.init_state`` and hands it to
    ``fixed_point_loop.fixpoint_iter``, which repeatedly applies
    ``state.update`` for as long as ``solver._continue`` allows, bounded by
    ``solver.min_iterations`` and ``solver.max_iterations``.

    Args:
      solver: The barycenter solver driving the loop.
      bar_size: Size of the barycenter.
      bar_prob: The barycenter problem.
      x_init: Initial barycenter estimate, forwarded to ``init_state``.
      rng: Random key, forwarded to ``init_state``.

    Returns:
      The result of ``solver.output_from_state`` on the final state.
    """

    def _keep_going(
        iteration: int,
        constants: Tuple[FreeWassersteinBarycenter,
                         barycenter_problem.FreeBarycenterProblem],
        state: FreeBarycenterState
    ) -> bool:
        # Only the solver half of the constants is needed to decide.
        return constants[0]._continue(state, iteration)

    def _step(
        iteration,
        constants: Tuple[FreeWassersteinBarycenter,
                         barycenter_problem.FreeBarycenterProblem],
        state: FreeBarycenterState,
        compute_error: bool
    ) -> FreeBarycenterState:
        del compute_error  # Always assumed True
        loop_solver, loop_prob = constants
        return state.update(
            iteration,
            loop_prob,
            loop_solver.linear_ot_solver,
            loop_solver.store_inner_errors,
        )

    final_state = fixed_point_loop.fixpoint_iter(
        cond_fn=_keep_going,
        body_fn=_step,
        min_iterations=solver.min_iterations,
        max_iterations=solver.max_iterations,
        inner_iterations=1,
        constants=(solver, bar_prob),
        state=solver.init_state(bar_prob, bar_size, x_init, rng),
    )
    return solver.output_from_state(final_state)