Source code for ott.utils

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import functools
import io
import warnings
from collections.abc import Sequence
from typing import (
    Any,
    Callable,
    List,
    NamedTuple,
    Optional,
    Tuple,
    TypeVar,
    Union,
)

try:
  from typing import ParamSpec
except ImportError:
  from typing_extensions import ParamSpec

import jax
import jax.numpy as jnp
import numpy as np
from jax.interpreters import batching

try:
  from tqdm import tqdm
except ImportError:
  tqdm = Any

__all__ = [
    "register_pytree_node",
    "deprecate",
    "default_prng_key",
    "default_progress_fn",
    "tqdm_progress_fn",
    "batched_vmap",
    "is_scalar",
]

IOStatus = Tuple[np.ndarray, np.ndarray, np.ndarray, NamedTuple]
IOCallback = Callable[[IOStatus], None]
P = ParamSpec("P")
R = TypeVar("R")


def register_pytree_node(cls: type) -> type:
  """Register dataclasses as pytree_nodes."""
  cls = dataclasses.dataclass()(cls)
  flatten = lambda obj: jax.tree_util.tree_flatten(dataclasses.asdict(obj))
  unflatten = lambda d, children: cls(**d.unflatten(children))
  jax.tree_util.register_pytree_node(cls, flatten, unflatten)
  return cls
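
# A minimal usage sketch (not part of the original module): `_DemoPoint` is a
# hypothetical dataclass shown only to illustrate `register_pytree_node`; once
# registered, its instances are traversable by `jax.tree` utilities.
#
#   @register_pytree_node
#   class _DemoPoint:
#     x: jnp.ndarray
#     y: jnp.ndarray
#
#   pt = _DemoPoint(x=jnp.zeros(3), y=jnp.ones(3))
#   jax.tree.leaves(pt)  # two leaves: `x` and `y`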


def deprecate(  # noqa: D103
    *,
    version: Optional[str] = None,
    alt: Optional[str] = None,
    func: Optional[Callable[P, R]] = None
) -> Callable[P, R]:

  def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
    warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
    return func(*args, **kwargs)

  if func is None:
    return lambda fn: deprecate(version=version, alt=alt, func=fn)

  msg = f"`{func.__name__}` will be removed in the "
  msg += ("next" if version is None else f"`ott-jax=={version}`") + " release."
  if alt:
    msg += " " + alt

  return functools.wraps(func)(wrapper)
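
# A minimal usage sketch (not part of the original module): `old_fn`, `new_fn`
# and the version string are hypothetical, shown only to illustrate `deprecate`.
#
#   @deprecate(version="0.5.0", alt="Use `new_fn` instead.")
#   def old_fn(x):
#     return x + 1
#
#   old_fn(1)  # emits a DeprecationWarning and returns 2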


def default_prng_key(rng: Optional[jax.Array] = None) -> jax.Array:
  """Get the default PRNG key.

  Args:
    rng: PRNG key.

  Returns:
    If ``rng = None``, returns the default PRNG key.
    Otherwise, it returns the unmodified ``rng`` key.
  """
  return jax.random.key(0) if rng is None else rng
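
# A minimal usage sketch (not part of the original module), illustrating the
# pass-through behavior of `default_prng_key`:
#
#   default_prng_key()                    # equivalent to jax.random.key(0)
#   default_prng_key(jax.random.key(42))  # returned unchanged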


def default_progress_fn(
    fmt: str = "{iter} / {max_iter} -- {error}",
    stream: Optional[io.TextIOBase] = None,
) -> IOCallback:
  """Return a callback that prints the progress when solving
  :mod:`linear problems <ott.problems.linear>`.

  It prints the progress only when the error is computed, that is every
  :attr:`~ott.solvers.linear.sinkhorn.Sinkhorn.inner_iterations`.

  Args:
    fmt: Format used to print. It can format ``iter``, ``max_iter`` and
      ``error`` values.
    stream: Output IO stream.

  Returns:
    A callback function accepting the following arguments

    - the current iteration number,
    - the number of inner iterations after which the error is computed,
    - the total number of iterations, and
    - the current :class:`~ott.solvers.linear.sinkhorn.SinkhornState` or
      :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`.

  Examples:
    .. code-block:: python

      import jax
      import jax.numpy as jnp

      from ott import utils
      from ott.geometry import pointcloud
      from ott.solvers import linear

      x = jax.random.normal(jax.random.key(0), (100, 5))
      geom = pointcloud.PointCloud(x)

      progress_fn = utils.default_progress_fn()
      solve_fn = jax.jit(linear.solve, static_argnames=["progress_fn"])
      out = solve_fn(geom, progress_fn=progress_fn)
  """  # noqa: D205

  def progress_callback(status: IOStatus) -> None:
    iteration, inner_iterations, total_iter, errors = _prepare_info(status)
    # Avoid reporting error on each iteration,
    # because errors are only computed every `inner_iterations`.
    if iteration % inner_iterations == 0:
      error_idx = max(0, iteration // inner_iterations - 1)
      error = errors[error_idx]
      print(
          fmt.format(iter=iteration, max_iter=total_iter, error=error),
          file=stream
      )

  return progress_callback


def tqdm_progress_fn(
    pbar: tqdm,
    fmt: str = "error: {error:0.6e}",
) -> IOCallback:
  """Return a callback that updates a progress bar when solving
  :mod:`linear problems <ott.problems.linear>`.

  It updates the progress bar only when the error is computed, that is every
  :attr:`~ott.solvers.linear.sinkhorn.Sinkhorn.inner_iterations`.

  Args:
    pbar: `tqdm <https://tqdm.github.io/docs/tqdm/>`_ progress bar.
    fmt: Format used for the postfix. It can format ``iter``, ``max_iter`` and
      ``error`` values.

  Returns:
    A callback function accepting the following arguments

    - the current iteration number,
    - the number of inner iterations after which the error is computed,
    - the total number of iterations, and
    - the current :class:`~ott.solvers.linear.sinkhorn.SinkhornState` or
      :class:`~ott.solvers.linear.sinkhorn_lr.LRSinkhornState`.

  Examples:
    .. code-block:: python

      import tqdm

      import jax
      import jax.numpy as jnp

      from ott import utils
      from ott.geometry import pointcloud
      from ott.solvers import linear

      x = jax.random.normal(jax.random.key(0), (100, 5))
      geom = pointcloud.PointCloud(x)

      with tqdm.tqdm() as pbar:
        progress_fn = utils.tqdm_progress_fn(pbar)
        solve_fn = jax.jit(linear.solve, static_argnames=["progress_fn"])
        out = solve_fn(geom, progress_fn=progress_fn)
  """  # noqa: D205

  def progress_callback(status: IOStatus) -> None:
    iteration, inner_iterations, total_iter, errors = _prepare_info(status)
    # Avoid reporting error on each iteration,
    # because errors are only computed every `inner_iterations`.
    if iteration % inner_iterations == 0:
      error_idx = max(0, iteration // inner_iterations - 1)
      error = errors[error_idx]
      postfix = fmt.format(iter=iteration, max_iter=total_iter, error=error)
      pbar.set_postfix_str(postfix)
      pbar.total = total_iter // inner_iterations
      pbar.update()

  return progress_callback


# Convert the raw callback status into Python ints and a flat array of errors.
def _prepare_info(status: IOStatus) -> Tuple[int, int, int, np.ndarray]:
  iteration, inner_iterations, total_iter, state = status
  iteration = int(iteration) + 1
  inner_iterations = int(inner_iterations)
  total_iter = int(total_iter)
  errors = np.array(state.errors).ravel()
  return iteration, inner_iterations, total_iter, errors


# Validate `axis` against `num_dims` and map negative axes to non-negative ones.
def _canonicalize_axis(axis: int, num_dims: int, *, has_extra_dim: bool) -> int:
  num_dims -= has_extra_dim
  if not -num_dims <= axis < num_dims:
    raise ValueError(
        f"axis {axis} is out of bounds for array of dimension {num_dims}"
    )
  if axis < 0:
    axis = axis + num_dims
  return axis


# Broadcast `axes` over the pytree leaves and canonicalize each axis.
def _prepare_axes(
    name: str,
    leaves: Any,
    treedef: Any,
    axes: Any,
    *,
    return_flat: bool,
    has_extra_dim: bool,
) -> Any:
  axes = jax.api_util.flatten_axes(name, treedef, axes, kws=False)
  assert len(leaves) == len(axes), (len(leaves), len(axes))
  axes = [
      axis if axis is None else
      _canonicalize_axis(axis, jnp.ndim(leaf), has_extra_dim=has_extra_dim)
      for axis, leaf in zip(axes, leaves)
  ]
  return axes if return_flat else treedef.unflatten(axes)


# Split `args` along `in_axes` into full batches of size `batch_size`
# (reshaped with an extra leading batch dimension) and a remainder tree
# holding the leftover elements.
def _batch_and_remainder(
    args: Any,
    *,
    batch_size: int,
    in_axes: Optional[Union[int, Sequence[int], Any]],
) -> Tuple[Any, Any]:
  assert batch_size > 0, f"Batch size must be positive, got {batch_size}."
  leaves, treedef = jax.tree.flatten(args, is_leaf=batching.is_vmappable)
  in_axes = _prepare_axes(
      "vmap in_axes",
      leaves,
      treedef,
      in_axes,
      return_flat=True,
      has_extra_dim=False,
  )
  assert not all(
      axis is None for axis in in_axes
  ), "vmap must have at least one non-None value in in_axes"

  num_splits = None
  has_scan, has_remainder = False, False
  scan_leaves, remainder_leaves = [], []

  for leaf, axis in zip(leaves, in_axes):
    if axis is None:
      scan_leaf = remainder_leaf = leaf
    else:
      if num_splits is None:
        num_splits = leaf.shape[axis] // batch_size
      else:
        curr_num_splits = leaf.shape[axis] // batch_size
        assert num_splits == curr_num_splits, \
          f"Expected {num_splits} splits, got {curr_num_splits}."

      num_elems = num_splits * batch_size
      scan_leaf = jax.lax.slice_in_dim(leaf, None, num_elems, axis=axis)
      new_shape = leaf.shape[:axis] + (-1, batch_size) + leaf.shape[axis + 1:]
      scan_leaf = scan_leaf.reshape(new_shape)
      remainder_leaf = jax.lax.slice_in_dim(leaf, num_elems, None, axis=axis)

      has_scan = has_scan or scan_leaf.shape[axis]
      has_remainder = has_remainder or remainder_leaf.shape[axis]

    scan_leaves.append(scan_leaf)
    remainder_leaves.append(remainder_leaf)

  scan_tree = treedef.unflatten(scan_leaves) if has_scan else None
  remainder_tree = treedef.unflatten(
      remainder_leaves
  ) if has_remainder else None
  return scan_tree, remainder_tree


# Apply `vmapped_fun` to each batch sequentially using `jax.lax.scan`.
def _apply_scan(
    vmapped_fun: Callable[P, R],
    in_axes: Optional[Union[int, Sequence[int], Any]],
) -> Callable[P, R]:

  def num_steps(axes: List[Any], args: Tuple[Any, ...]) -> int:
    ix = next(ix for ix, axis in enumerate(axes) if axis is not None)
    leaf, *_ = jax.tree.leaves(args[ix])
    axis, *_ = jax.tree.leaves(axes[ix])
    return leaf.shape[axis]

  def select_batch(arg: Any, index: int, *, axis: int) -> Any:
    fn = functools.partial(
        jax.lax.dynamic_index_in_dim, index=index, axis=axis, keepdims=False
    )
    return jax.tree.map(fn, arg)

  def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:

    def body_fn(carry: None, index: int) -> Tuple[None, R]:
      del carry
      new_args = treedef.unflatten([
          arg if axis is None else select_batch(arg, index, axis=axis)
          for arg, axis in zip(leaves, axes)
      ])
      return None, vmapped_fun(*new_args, **kwargs)

    leaves, treedef = jax.tree.flatten(args, is_leaf=batching.is_vmappable)
    axes = _prepare_axes(
        "vmap in_axes",
        leaves,
        treedef,
        in_axes,
        return_flat=True,
        has_extra_dim=True,
    )

    n = num_steps(axes, args)
    _, result = jax.lax.scan(body_fn, init=None, xs=jnp.arange(n))
    return result

  return wrapper


def batched_vmap(
    fun: Callable[P, R],
    *,
    batch_size: int,
    in_axes: Optional[Union[int, Sequence[int], Any]] = 0,
    out_axes: Any = 0,
) -> Callable[P, R]:
  """Batched version of :func:`~jax.vmap`.

  Args:
    fun: Function to be mapped over additional axes.
    batch_size: Size of the batch.
    in_axes: Values specifying which input array axes to map over.
    out_axes: Values specifying where the mapped axis should appear in the
      output.

  Returns:
    The vectorized function.
  """

  def unbatch(axis: int, x: jnp.ndarray) -> jnp.ndarray:
    x = jnp.moveaxis(x, 0, axis)
    return jax.lax.collapse(x, axis, axis + 2)

  def concat(axis: int, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    return jnp.concatenate([x, y], axis=axis)

  @functools.wraps(fun)
  def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
    batched, remainder = _batch_and_remainder(
        args, batch_size=batch_size, in_axes=in_axes
    )

    has_batched = batched is not None
    has_remainder = remainder is not None

    if has_batched:
      batched = batched_fun(*batched, **kwargs)
      leaves, treedef = jax.tree.flatten(
          batched, is_leaf=batching.is_vmappable
      )
      out_axes_ = _prepare_axes(
          "vmap out_axes",
          leaves,
          treedef,
          out_axes,
          return_flat=False,
          has_extra_dim=True,
      )
      batched = jax.tree.map(unbatch, out_axes_, batched)
    if has_remainder:
      remainder = vmapped_fun(*remainder, **kwargs)

    if has_batched and has_remainder:
      return jax.tree.map(concat, out_axes_, batched, remainder)
    if has_batched:
      return batched
    # TODO(michalk8): check for empty arrays
    return remainder

  if isinstance(in_axes, list):
    in_axes = tuple(in_axes)

  vmapped_fun = jax.vmap(fun, in_axes=in_axes, out_axes=out_axes)
  batched_fun = _apply_scan(vmapped_fun, in_axes=in_axes)

  return wrapper
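

# A minimal usage sketch (not part of the original module): `pairwise_dist`
# and the shapes below are hypothetical, shown only to illustrate
# `batched_vmap`. Rows of `x` are processed in chunks of 128 via
# `jax.lax.scan`, while `y` is held fixed; the leftover 1000 % 128 rows are
# handled by a plain `jax.vmap`.
#
#   def pairwise_dist(x, y):  # x: (d,), y: (m, d) -> (m,)
#     return jnp.linalg.norm(x[None, :] - y, axis=-1)
#
#   x = jax.random.normal(jax.random.key(0), (1000, 5))
#   y = jax.random.normal(jax.random.key(1), (200, 5))
#   dist_fn = batched_vmap(pairwise_dist, batch_size=128, in_axes=(0, None))
#   dist_fn(x, y).shape  # (1000, 200); same result as jax.vmap, lower memory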


# TODO(michalk8): remove when `jax>=0.4.31`
def is_scalar(x: Any) -> bool:  # noqa: D103
  if (
      isinstance(x, (np.ndarray, jax.Array)) or hasattr(x, "__jax_array__") or
      np.isscalar(x)
  ):
    return jnp.asarray(x).ndim == 0
  return False
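
# A minimal usage sketch (not part of the original module), illustrating
# `is_scalar`:
#
#   is_scalar(1.0)            # True
#   is_scalar(jnp.zeros(()))  # True
#   is_scalar(jnp.zeros(3))   # False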