Source code for ott.math.fixed_point_loop

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable

import jax
import jax.numpy as jnp
import numpy as np

__all__ = ["fixpoint_iter", "fixpoint_iter_backprop"]


def fixpoint_iter(
    cond_fn: Callable[[int, Any, Any], bool],
    body_fn: Callable[[Any, Any, Any, Any], Any],
    min_iterations: int,
    max_iterations: int,
    inner_iterations: int,
    constants: Any,
    state: Any,
):
    """Run a fixed point loop made of blocks of ``inner_iterations`` steps.

    Each step calls ``body_fn(iteration, constants, state, compute_error)`` and
    uses its output as the new state. Within a block of ``inner_iterations``
    consecutive steps, ``compute_error`` is ``False`` for every step except the
    last one, signalling that extra work may be spent on refreshing the error
    (by convention, errors are kept as the first element of the state tuple).

    After each block the loop goes on only while the iteration counter is
    below ``max_iterations`` and either the counter is below
    ``min_iterations`` or ``cond_fn(iteration, constants, state)`` is true.
    Because the test happens between blocks, the counter always advances by
    multiples of ``inner_iterations``.

    When ``min_iterations == max_iterations`` the number of blocks is known in
    advance, so a ``jax.lax.scan`` over ``max_iterations // inner_iterations``
    blocks is used instead of a ``jax.lax.while_loop`` and ``cond_fn`` is never
    evaluated.

    Args:
      cond_fn: termination condition function; returning ``False`` interrupts
        the loop once ``min_iterations`` has been reached.
      body_fn: body loop instructions.
      min_iterations: lower bound on the total amount of fixed point iterations.
      max_iterations: upper bound on the total amount of fixed point iterations.
      inner_iterations: number of iterations ``body_fn`` will be executed
        successively before calling ``cond_fn``.
      constants: constant (during loop) parameters passed on to body.
      state: state variable.

    Returns:
      The state returned by ``body_fn`` upon termination.
    """
    # Boolean mask over one block: only its final step is flagged.
    last_step_mask = jnp.arange(inner_iterations) == inner_iterations - 1

    def run_block(carry):
        """Advance ``(iteration, state)`` by one block of inner iterations."""

        def step(inner_carry, compute_error):
            it, current = inner_carry
            updated = body_fn(it, constants, current, compute_error)
            return (it + 1, updated), None

        carry, _ = jax.lax.scan(step, carry, last_step_mask)
        return carry

    if min_iterations == max_iterations:
        # Fixed iteration budget: a scan over whole blocks replaces the
        # while loop, and the stopping criterion is not consulted.
        num_blocks = max_iterations // inner_iterations
        (_, state), _ = jax.lax.scan(
            lambda carry, _: (run_block(carry), None),
            (0, state),
            None,
            length=num_blocks,
        )
        return state

    def keep_going(carry):
        """Return whether another block of iterations should be run."""
        it, current = carry
        under_cap = it < max_iterations
        wanted = jnp.logical_or(
            it < min_iterations, cond_fn(it, constants, current)
        )
        return jnp.logical_and(under_cap, wanted)

    _, state = jax.lax.while_loop(keep_going, run_block, (0, state))
    return state


def fixpoint_iter_fwd(
    cond_fn, body_fn, min_iterations, max_iterations, inner_iterations,
    constants, state
):
    """Forward iteration of fixed point iteration to handle backpropagation.

    The main difference with ``fixpoint_iter`` is the checkpointing, in the
    variable ``states``, of the state as it stands at the start of every block
    of ``inner_iterations`` steps. This sequence of states is used by the
    backward loop.

    Args:
      cond_fn: termination condition function.
      body_fn: body loop instructions.
      min_iterations: lower bound on the total amount of fixed point iterations.
      max_iterations: upper bound on the total amount of fixed point iterations.
      inner_iterations: number of iterations ``body_fn`` will be executed
        successively before calling ``cond_fn``.
      constants: constant (during loop) parameters passed on to body.
      state: state variable.

    Returns:
      A pair ``(state, residuals)`` where ``state`` is the state returned by
      ``body_fn`` upon termination and ``residuals`` is the tuple
      ``(constants, iteration, states)`` consumed by ``fixpoint_iter_bwd``:
      the constants, the final iteration counter and the checkpointed states.
    """
    force_scan = min_iterations == max_iterations
    # Only the last step of each block is flagged to compute the error.
    compute_error_flags = jnp.arange(inner_iterations) == inner_iterations - 1
    # Zero-filled checkpoint buffers: every leaf of ``state`` gets a new
    # leading axis with ``max_iterations // inner_iterations + 1`` slots.
    states = jax.tree_util.tree_map(
        lambda x: jnp.zeros(
            (max_iterations // inner_iterations + 1,) + jnp.shape(x),
            dtype=jax.dtypes.result_type(x),
        ),
        state,
    )

    def max_cond_fn(iteration_states_state):
        iteration, _, state = iteration_states_state
        return jnp.logical_and(
            iteration < max_iterations,
            jnp.logical_or(
                iteration < min_iterations,
                cond_fn(iteration, constants, state),
            ),
        )

    def unrolled_body_fn(iteration_states_state):
        iteration, states, state = iteration_states_state
        # Checkpoint the state entering this block at slot
        # ``iteration // inner_iterations``.
        states = jax.tree_util.tree_map(
            lambda states, state: jax.lax.dynamic_update_index_in_dim(
                states, state, iteration // inner_iterations, 0
            ),
            states,
            state,
        )

        def one_iteration(iteration_state, compute_error):
            iteration, state = iteration_state
            state = body_fn(iteration, constants, state, compute_error)
            iteration += 1
            return (iteration, state), None

        iteration_state, _ = jax.lax.scan(
            one_iteration, (iteration, state), compute_error_flags
        )
        iteration, state = iteration_state
        out = (iteration, states, state)
        # ``scan`` expects ``(carry, output)``; ``while_loop`` just the carry.
        return (out, None) if force_scan else out

    if force_scan:
        (iteration, states, state), _ = jax.lax.scan(
            lambda carry, x: unrolled_body_fn(carry),
            (0, states, state),
            None,
            length=max_iterations // inner_iterations,
        )
    else:
        iteration, states, state = jax.lax.while_loop(
            max_cond_fn, unrolled_body_fn, (0, states, state)
        )

    return state, (constants, iteration, states)


def fixpoint_iter_bwd(
    cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, res, g
):
    """Backward iteration of fixed point iteration, using checkpointed states.

    Starting from the last block executed by the forward pass, each block of
    ``inner_iterations`` steps is re-run from its checkpointed input state
    (with ``compute_error`` set to ``False`` throughout) and the incoming
    cotangent is pulled back through it with ``jax.vjp``. Cotangents with
    respect to ``constants`` are accumulated over all blocks.

    Args:
      cond_fn: termination condition function; unused in the backward pass.
      body_fn: body loop instructions.
      min_iterations: lower bound on the total amount of fixed point iterations.
      max_iterations: upper bound on the total amount of fixed point iterations.
      inner_iterations: number of iterations of ``body_fn`` per block.
      res: residuals ``(constants, iteration, states)`` produced by
        ``fixpoint_iter_fwd``.
      g: cotangent of the final state.

    Returns:
      A pair ``(g_constants, g_state)`` holding the cotangents with respect to
      ``constants`` and to the initial ``state``.
    """
    del cond_fn
    force_scan = min_iterations == max_iterations
    constants, iteration, states = res
    # The tree may contain some python floats
    g_constants = jax.tree_util.tree_map(
        lambda x: jnp.zeros_like(x, dtype=x.dtype)
        if isinstance(x, (np.ndarray, jnp.ndarray)) else 0,
        constants,
    )

    def bwd_cond_fn(iteration_g_gconst):
        iteration, _, _ = iteration_g_gconst
        return iteration >= 0

    def unrolled_body_fn_no_errors(iteration, constants, state):
        # Replay one block without ever requesting the error computation.
        compute_error_flags = jnp.zeros((inner_iterations,), dtype=bool)

        def one_iteration(iteration_state, compute_error):
            iteration, state = iteration_state
            state = body_fn(iteration, constants, state, compute_error)
            iteration += 1
            return (iteration, state), None

        iteration_state, _ = jax.lax.scan(
            one_iteration, (iteration, state), compute_error_flags
        )
        _, state = iteration_state
        return state

    def unrolled_body_fn(iteration_g_gconst):
        iteration, g, g_constants = iteration_g_gconst
        # Input state of the block that started at ``iteration``.
        state = jax.tree_util.tree_map(
            lambda x: x[iteration // inner_iterations], states
        )
        _, pullback = jax.vjp(
            unrolled_body_fn_no_errors, iteration, constants, state
        )
        _, gi_constants, g_state = pullback(g)
        g_constants = jax.tree_util.tree_map(
            lambda x, y: x + y, g_constants, gi_constants
        )
        out = (iteration - inner_iterations, g_state, g_constants)
        # ``scan`` expects ``(carry, output)``; ``while_loop`` just the carry.
        return (out, None) if force_scan else out

    # Both loop flavours must begin at the start of the LAST forward block and
    # walk back to iteration 0. Beginning the scan at 0 would make the
    # decreasing counter negative, so ``iteration // inner_iterations`` would
    # wrap around to the unused trailing checkpoint slot and visit the
    # checkpoints in the wrong order.
    init = (iteration - inner_iterations, g, g_constants)
    if force_scan:
        (_, g_state, g_constants), _ = jax.lax.scan(
            lambda carry, x: unrolled_body_fn(carry),
            init,
            None,
            length=max_iterations // inner_iterations,
        )
    else:
        _, g_state, g_constants = jax.lax.while_loop(
            bwd_cond_fn, unrolled_body_fn, init
        )

    return g_constants, g_state


# Backprop-friendly variant of fixpoint_iter: the first five arguments
# (cond_fn, body_fn, min/max/inner iteration counts) are non-differentiable;
# gradients flow to ``constants`` and ``state`` only.
fixpoint_iter_backprop = jax.custom_vjp(
    fixpoint_iter, nondiff_argnums=(0, 1, 2, 3, 4)
)
fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd)