ott.neural.networks.velocity_field.ema.EMA

class ott.neural.networks.velocity_field.ema.EMA(model, *, decay)

Exponential moving average (EMA) of a model's parameters.

Parameters:
  • model (Module) – Model to average.

  • decay (float) – EMA decay factor.

  • args (Any)

  • kwargs (Any)

Return type:

Any
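The role of `decay` can be illustrated with the conventional EMA update, in which each parameter leaf is blended as `decay * ema + (1 - decay) * current`. The sketch below assumes that convention (with `decay` weighting the previous average); it is not taken from OTT's source, and `ema_update` is a hypothetical helper, not part of OTT's API. To stay dependency-free it walks nested dicts and lists of floats instead of real model state.

```python
from typing import Any


def ema_update(ema_params: Any, params: Any, decay: float) -> Any:
    """Hypothetical helper (not OTT's API): one EMA step, leaf by leaf.

    Assumes the conventional rule new = decay * old + (1 - decay) * current,
    so a decay close to 1 makes the average change slowly.
    """
    if not 0.0 <= decay <= 1.0:
        raise ValueError("decay must lie in [0, 1]")
    if isinstance(ema_params, dict):
        # Recurse into each named parameter group.
        return {k: ema_update(ema_params[k], params[k], decay) for k in ema_params}
    if isinstance(ema_params, (list, tuple)):
        # Recurse element-wise, preserving the container type.
        return type(ema_params)(
            ema_update(e, p, decay) for e, p in zip(ema_params, params)
        )
    # Leaf: blend the running average with the current value.
    return decay * ema_params + (1.0 - decay) * params


# Usage: start from zeros and repeatedly average in fixed "trained" values.
ema = {"w": [0.0, 0.0], "b": 0.0}
for _ in range(3):
    params = {"w": [1.0, 2.0], "b": 1.0}  # stand-in for optimizer output
    ema = ema_update(ema, params, decay=0.5)

print(ema)  # {'w': [0.875, 1.75], 'b': 0.875}
```

With `decay=0.5` the average closes half of the remaining gap at each step; with a typical value such as `0.999` it roughly averages over the last `1 / (1 - decay) = 1000` updates. In JAX code the recursion is usually delegated to `jax.tree_util.tree_map`.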

Methods

eval(**attributes)

Sets the Module to evaluation mode.

iter_children()

Warning: this method is deprecated; use iter_children() instead.

iter_modules()

Warning: this method is deprecated; use iter_modules() instead.

perturb(name, value[, variable_type])

Extract gradients of intermediate values during training.

sow(variable_type, name, value[, reduce_fn, ...])

Store intermediate values during module execution for later extraction.

train(**attributes)

Sets the Module to training mode.