Source code for ott.neural.networks.velocity_field.ema
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
import jax.numpy as jnp
from flax import nnx
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["EMA", "init_ema", "update_ema"]
[docs]
class EMA(nnx.Module):
    """Exponential moving average (EMA) of a model.

    On construction, a copy of ``model`` whose :class:`~flax.nnx.Param` leaves
    are all zeros is created via :func:`init_ema`. Each call then blends the
    current parameters of a model into that copy::

        ema = decay * ema + (1 - decay) * model

    Args:
        model: Model to average.
        decay: EMA decay factor, must lie in ``[0, 1]``. Larger values make the
            average move more slowly.

    Raises:
        ValueError: If ``decay`` is not in ``[0, 1]``.
    """

    def __init__(self, model: nnx.Module, *, decay: float):
        super().__init__()
        # Outside [0, 1] the update is no longer a convex combination of the
        # previous average and the new parameters, so the "average" can grow
        # without bound or flip sign between steps. NaN also fails this check.
        if not 0.0 <= decay <= 1.0:
            raise ValueError(f"Expected `decay` to be in [0, 1], got {decay}.")
        self.ema = init_ema(model)
        self.decay = decay

    def __call__(self, model: nnx.Module) -> None:
        """Update the EMA with the current parameters of ``model``.

        Args:
            model: Model to average. Its parameters must have the same tree
                structure as those of the model this EMA was built from, since
                they are combined leaf by leaf.

        Returns:
            Nothing, ``self.ema`` is updated in-place.
        """
        update_ema(model, ema=self.ema, decay=self.decay)
[docs]
def init_ema(model: nnx.Module) -> nnx.Module:
    """Build a zero-initialized exponential moving average (EMA) of a model.

    Args:
        model: Model to average.

    Returns:
        A new module with the same graph definition as ``model``. Every
        :class:`~flax.nnx.Param` leaf is replaced by zeros of the same shape
        and dtype, and all remaining state is copied over unchanged.
    """
    graph_def, params, other_state = nnx.split(model, nnx.Param, ...)

    def clone(leaf):
        return leaf.copy()

    return nnx.merge(
        graph_def,
        jax.tree.map(jnp.zeros_like, params),
        # Non-parameter state (e.g. RNGs, batch statistics) is copied, not
        # zeroed.
        jax.tree.map(clone, other_state),
    )
[docs]
def update_ema(model: nnx.Module, *, ema: nnx.Module, decay: float) -> None:
    """Update the exponential moving average (EMA) of a model in-place.

    Every :class:`~flax.nnx.Param` leaf of ``ema`` becomes::

        decay * p_ema + (1 - decay) * p_model

    All remaining (non-``Param``) state of ``model``, e.g. RNGs or batch
    statistics, is copied into ``ema`` unchanged.

    Args:
        model: Model to average.
        ema: EMA of the model, e.g. as created by :func:`init_ema`. Its
            parameters must have the same tree structure as those of
            ``model``, because they are combined leaf by leaf.
        decay: Decay factor.

    Returns:
        Nothing, just updates ``ema`` in-place.
    """

    # ``jax.tree.map`` calls this on matching leaves of the two parameter
    # states.
    def update_fn(p_model: jax.Array, p_ema: jax.Array) -> jax.Array:
        return p_ema * decay + p_model * (1.0 - decay)

    params, rest = nnx.state(model, nnx.Param, ...)
    # Only the EMA's parameters are read here; its graph definition and
    # non-parameter state are not needed, so ``nnx.split`` is unnecessary.
    ema_params = nnx.state(ema, nnx.Param)
    rest = jax.tree.map(lambda r: r.copy(), rest)
    new_ema_params = jax.tree.map(update_fn, params, ema_params)
    nnx.update(ema, new_ema_params, rest)