ott.neural.networks.velocity_field.VelocityField
- class ott.neural.networks.velocity_field.VelocityField(hidden_dims, output_dims, condition_dims=None, time_dims=None, time_encoder=cyclical_time_encoder, act_fn=silu, dropout_rate=0.0, parent=<flax.linen.module._Sentinel object>, name=None)
Neural vector field.
This class learns a map \(v: \mathbb{R}\times \mathbb{R}^d \rightarrow \mathbb{R}^d\) that defines the ODE \(\frac{dx}{dt} = v(t, x)\). Integrating \(v(t, x)\) from \(t=t_0\) to \(t=t_1\) transports a source distribution given at time \(t_0\) to a target distribution at time \(t_1\).
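Transport amounts to numerically integrating the ODE. The sketch below uses a fixed-step forward Euler scheme in JAX; the helper name euler_transport, the step count, and the stand-in field v(t, x) = -x (used in place of a trained VelocityField so the result can be checked against the exact flow x0 * exp(-t)) are all assumptions for illustration. A higher-order or adaptive solver may be preferable in practice.

```python
import jax
import jax.numpy as jnp


def euler_transport(v, x0, t0=0.0, t1=1.0, n_steps=100):
    """Push samples x0 from time t0 to t1 with forward Euler on dx/dt = v(t, x)."""
    dt = (t1 - t0) / n_steps

    def step(x, i):
        t = t0 + i * dt           # current time of this Euler step
        return x + dt * v(t, x), None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(n_steps))
    return x1


# Hypothetical stand-in velocity field with known flow x(t) = x0 * exp(-t).
def v(t, x):
    return -x


x0 = jnp.array([[1.0, 2.0], [-1.0, 0.5]])  # two 2-D source samples at t0 = 0
x1 = euler_transport(v, x0, n_steps=1000)
print(x1)  # approximately x0 * exp(-1)
```

Using jax.lax.scan keeps the loop traceable, so the whole transport can be wrapped in jax.jit or differentiated with respect to the field's parameters.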
- Parameters:
hidden_dims (Sequence[int]) – Dimensionality of the embedding of the data.
output_dims (Sequence[int]) – Dimensionality of the embedding of the output.
condition_dims (Optional[Sequence[int]]) – Dimensionality of the embedding of the condition. If None, the velocity field has no conditions.
time_dims (Optional[Sequence[int]]) – Dimensionality of the time embedding. If None, hidden_dims is used.
time_encoder (Callable[[Array], Array]) – Time encoder for the velocity field.
act_fn – Activation function; defaults to SiLU.
dropout_rate (float) – Dropout rate; the default 0.0 disables dropout.
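The time_encoder turns the scalar time into a feature vector before it enters the network. The default, cyclical_time_encoder, is described in the method table as encoding \(t\) into a cyclical representation. The sketch below shows one common form of such an encoding: cosines and sines of \(2\pi k t\) for \(k = 1, \dots, n_f\). The function name cyclical_time_encoding, the frequency layout, and the n_freqs default are assumptions, not necessarily what ott implements.

```python
import jax.numpy as jnp


def cyclical_time_encoding(t, n_freqs=8):
    """Hypothetical sketch: map times of shape [n, 1] to [n, 2 * n_freqs] cos/sin features."""
    # Angular frequencies 2*pi*k for k = 1, ..., n_freqs (assumed layout).
    freqs = 2.0 * jnp.pi * jnp.arange(1, n_freqs + 1)
    angles = t * freqs  # broadcasts [n, 1] * [n_freqs] -> [n, n_freqs]
    return jnp.concatenate([jnp.cos(angles), jnp.sin(angles)], axis=-1)


t = jnp.array([[0.0], [0.25], [1.0]])
emb = cyclical_time_encoding(t, n_freqs=4)
print(emb.shape)  # (3, 8)
```

With integer frequencies this sketch is periodic with period 1, so t = 0 and t = 1 map to the same features; the choice of frequencies determines which time differences the network can resolve.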
Methods
act_fn() – SiLU (aka swish) activation function.
apply(variables, *args[, rngs, method, ...]) – Applies a module method to variables and returns output and modified variables.
clone(*[, parent, _deep_clone, _reset_names]) – Creates a clone of this Module, with optionally updated arguments.
copy(*[, parent, name]) – Creates a copy of this Module, with optionally updated arguments.
create_train_state(rng, optimizer, input_dim) – Creates the training state.
get_variable(col, name[, default]) – Retrieves the value of a Variable.
has_rng(name) – Returns True if a PRNGSequence with the given name exists.
has_variable(col, name) – Checks if a variable of the given collection and name exists in this Module.
init(rngs, *args[, method, mutable, ...]) – Initializes a module method with variables and returns modified variables.
init_with_output(rngs, *args[, method, ...]) – Initializes a module method with variables and returns output and modified variables.
is_initializing() – Returns True if running under self.init(...) or nn.init(...)().
is_mutable_collection(col) – Returns True if the collection col is mutable.
lazy_init(rngs, *args[, method, mutable]) – Initializes a module without computing on an actual input.
make_rng([name]) – Returns a new RNG key from a given RNG sequence for this Module.
param(name, init_fn, *init_args[, unbox]) – Declares and returns a parameter in this Module.
perturb(name, value[, collection]) – Adds a zero-valued variable ('perturbation') to the intermediate value.
put_variable(col, name, value) – Updates the value of the given variable if it is mutable, or raises an error otherwise.
setup() – Initializes a Module lazily (similar to a lazy __init__).
sow(col, name, value[, reduce_fn, init_fn]) – Stores a value in a collection.
time_encoder([n_freqs]) – Encodes time \(t\) into a cyclical representation.
unbind() – Returns an unbound copy of a Module and its variables.
variable(col, name[, init_fn, unbox]) – Declares and returns a variable in this Module.
Attributes
path – Gets the path of this Module.
variables – Returns the variables in this module.