# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable
import jax
import jax.numpy as jnp
import numpy as np
__all__ = ["fixpoint_iter", "fixpoint_iter_backprop"]
[docs]def fixpoint_iter(
cond_fn: Callable[[int, Any, Any], bool],
body_fn: Callable[[Any, Any, Any, Any], Any], min_iterations: int,
max_iterations: int, inner_iterations: int, constants: Any, state: Any
):
"""Implementation of a fixed point loop.
This fixed point loop iterator applies body_fn to a tuple
(iteration, constants, state, compute_error) to output a new state, using
context provided in iteration and constants.
body_fn is iterated (inner_iterations -1) times, and one last time with the
compute_error flag indicating that additional computational effort can be
spent on recalculating the latest error (errors are stored as the first
element of the state tuple).
upon termination of these inner_iterations, the loop is continued if iteration
is smaller than min_iterations, stopped if equal/larger than max_iterations,
and interrupted if cond_fn returns False.
Args:
cond_fn : termination condition function
body_fn : body loop instructions
min_iterations : lower bound on the total amount of fixed point iterations
max_iterations : upper bound on the total amount of fixed point iterations
inner_iterations : number of iterations body_fn will be executed
successively before calling cond_fn.
constants : constant (during loop) parameters passed on to body
state : state variable
Returns:
outputs state returned by body_fn upon termination.
""" # noqa: D401
# If number of minimal iterations matches maximal number, force a scan instead
# of a while loop.
force_scan = (min_iterations == max_iterations)
compute_error_flags = jnp.arange(inner_iterations) == inner_iterations - 1
def max_cond_fn(iteration_state):
iteration, state = iteration_state
return jnp.logical_and(
iteration < max_iterations,
jnp.logical_or(
iteration < min_iterations, cond_fn(iteration, constants, state)
)
)
def unrolled_body_fn(iteration_state):
def one_iteration(iteration_state, compute_error):
iteration, state = iteration_state
state = body_fn(iteration, constants, state, compute_error)
iteration += 1
return (iteration, state), None
iteration_state, _ = jax.lax.scan(
one_iteration, iteration_state, compute_error_flags
)
return (iteration_state, None) if force_scan else iteration_state
if force_scan:
(_, state), _ = jax.lax.scan(
lambda carry, x: unrolled_body_fn(carry), (0, state),
None,
length=max_iterations // inner_iterations
)
else:
_, state = jax.lax.while_loop(max_cond_fn, unrolled_body_fn, (0, state))
return state
def fixpoint_iter_fwd(
cond_fn, body_fn, min_iterations, max_iterations, inner_iterations,
constants, state
):
"""Forward iteration of fixed point iteration to handle backpropagation.
The main difference with fixpoint_iter is the checkpointing, in variable
states, of the state variables as they are recorded through iterations, every
inner_iterations. This sequence of states will be used in the backward loop.
Args:
cond_fn : termination condition function
body_fn : body loop instructions
min_iterations : lower bound on the total amount of fixed point iterations
max_iterations : upper bound on the total amount of fixed point iterations
inner_iterations : number of iterations body_fn will be executed
successively before calling cond_fn.
constants : constant (during loop) parameters passed on to body
state : state variable
Returns:
outputs state returned by body_fn upon termination.
"""
force_scan = min_iterations == max_iterations
compute_error_flags = jnp.arange(inner_iterations) == inner_iterations - 1
states = jax.tree_util.tree_map(
lambda x: jnp.zeros(
(max_iterations // inner_iterations + 1,) + jnp.shape(x),
dtype=jax.dtypes.result_type(x)
), state
)
def max_cond_fn(iteration_states_state):
iteration, _, state = iteration_states_state
return jnp.logical_and(
iteration < max_iterations,
jnp.logical_or(
iteration < min_iterations, cond_fn(iteration, constants, state)
)
)
def unrolled_body_fn(iteration_states_state):
iteration, states, state = iteration_states_state
states = jax.tree_util.tree_map(
lambda states, state: jax.lax.dynamic_update_index_in_dim(
states, state, iteration // inner_iterations, 0
), states, state
)
def one_iteration(iteration_state, compute_error):
iteration, state = iteration_state
state = body_fn(iteration, constants, state, compute_error)
iteration += 1
return (iteration, state), None
iteration_state, _ = jax.lax.scan(
one_iteration, (iteration, state), compute_error_flags
)
iteration, state = iteration_state
out = (iteration, states, state)
return (out, None) if force_scan else out
if force_scan:
(iteration, states, state), _ = jax.lax.scan(
lambda carry, x: unrolled_body_fn(carry), (0, states, state),
None,
length=max_iterations // inner_iterations
)
else:
iteration, states, state = jax.lax.while_loop(
max_cond_fn, unrolled_body_fn, (0, states, state)
)
return state, (constants, iteration, states)
def fixpoint_iter_bwd(
cond_fn, body_fn, min_iterations, max_iterations, inner_iterations, res, g
):
"""Backward iteration of fixed point iteration, using checkpointed states."""
del cond_fn
force_scan = (min_iterations == max_iterations)
constants, iteration, states = res
# The tree may contain some python floats
g_constants = jax.tree_util.tree_map(
lambda x: jnp.zeros_like(x, dtype=x.dtype)
if isinstance(x, (np.ndarray, jnp.ndarray)) else 0, constants
)
def bwd_cond_fn(iteration_g_gconst):
iteration, _, _ = iteration_g_gconst
return iteration >= 0
def unrolled_body_fn_no_errors(iteration, constants, state):
compute_error_flags = jnp.zeros((inner_iterations,), dtype=bool)
def one_iteration(iteration_state, compute_error):
iteration, state = iteration_state
state = body_fn(iteration, constants, state, compute_error)
iteration += 1
return (iteration, state), None
iteration_state, _ = jax.lax.scan(
one_iteration, (iteration, state), compute_error_flags
)
_, state = iteration_state
return state
def unrolled_body_fn(iteration_g_gconst):
iteration, g, g_constants = iteration_g_gconst
state = jax.tree_util.tree_map(
lambda x: x[iteration // inner_iterations], states
)
_, pullback = jax.vjp(
unrolled_body_fn_no_errors, iteration, constants, state
)
_, gi_constants, g_state = pullback(g)
g_constants = jax.tree_util.tree_map(
lambda x, y: x + y, g_constants, gi_constants
)
out = (iteration - inner_iterations, g_state, g_constants)
return (out, None) if force_scan else out
if force_scan:
(_, g_state, g_constants), _ = jax.lax.scan(
lambda carry, x: unrolled_body_fn(carry), (0, g, g_constants),
None,
length=max_iterations // inner_iterations
)
else:
_, g_state, g_constants = jax.lax.while_loop(
bwd_cond_fn, unrolled_body_fn,
(iteration - inner_iterations, g, g_constants)
)
return g_constants, g_state
# definition of backprop friendly variant of fixpoint_iter.
fixpoint_iter_backprop = jax.custom_vjp(
fixpoint_iter, nondiff_argnums=(0, 1, 2, 3, 4)
)
fixpoint_iter_backprop.defvjp(fixpoint_iter_fwd, fixpoint_iter_bwd)