ott.neural.networks.potentials.BasePotential
class ott.neural.networks.potentials.BasePotential(parent=<flax.linen.module._Sentinel object>, name=None)
Base class for the neural solver models (Linen).
Kept for backward compatibility with ExpectileNeuralDual and MongeGapEstimator. New code should use BaseDualPotential (NNX) instead.
Methods
apply(variables, *args[, rngs, method, ...]): Applies a module method to variables and returns output and modified variables.
clone(*[, parent, _deep_clone, _reset_names]): Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]): Creates a copy of this Module, with optionally updated arguments.
create_train_state(rng, optimizer, input, ...): Create initial training state.
get_variable(col, name[, default]): Retrieves the value of a Variable.
has_rng(name): Returns true if a PRNGSequence with name 'name' exists.
has_variable(col, name): Checks if a variable of given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]): Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]): Initializes a module method with variables and returns output and modified variables.
Returns True if running under self.init(...) or nn.init(...)().
Returns true if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]): Initializes a module without computing on an actual input.
make_rng([name]): Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args[, unbox])
perturb(name, value[, collection]): Adds a zero-value variable ('perturbation') to the intermediate value.
potential_gradient_fn(params): Return a function returning a vector or the gradient of the potential.
potential_value_fn(params[, ...]): Return a function giving the value of the potential.
put_variable(col, name, value): Updates the value of the given variable if it is mutable, or an error otherwise.
setup(): Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn])
unbind(): Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox])
Attributes