ott.neural.networks.velocity_field.VelocityField

class ott.neural.networks.velocity_field.VelocityField(hidden_dims, output_dims, condition_dims=None, time_dims=None, time_encoder=<function cyclical_time_encoder>, act_fn=<PjitFunction of <function silu>>, dropout_rate=0.0, parent=<flax.linen.module._Sentinel object>, name=None)

Neural vector field.

This class learns a map \(v: \mathbb{R}\times \mathbb{R}^d \rightarrow \mathbb{R}^d\) that defines the ODE \(\frac{dx}{dt} = v(t, x)\). Given a source distribution at time \(t_0\), integrating \(v(t, x)\) from \(t=t_0\) to \(t=t_1\) transports it to a target distribution at time \(t_1\).
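The transport step described above can be illustrated with a minimal sketch. It does not use the library: `translation_field` is a hypothetical hand-written stand-in for a trained velocity field, and `euler_transport` is an illustrative forward Euler integrator written in plain Python (an assumption made to keep the example dependency-free; in practice one would likely use a JAX-based ODE solver).

```python
def translation_field(t, x, shift=(2.0, -1.0)):
    """Hypothetical stand-in for a trained velocity field v(t, x).

    It ignores t and x and always returns the same velocity, so the
    induced flow translates every point by `shift` over unit time.
    """
    return list(shift)


def euler_transport(v, x0, t0=0.0, t1=1.0, n_steps=100):
    """Integrate dx/dt = v(t, x) from t0 to t1 with forward Euler steps."""
    dt = (t1 - t0) / n_steps
    x = list(x0)
    for i in range(n_steps):
        t = t0 + i * dt
        velocity = v(t, x)
        x = [xi + dt * vi for xi, vi in zip(x, velocity)]
    return x


# Samples from a toy "source distribution" at time t0 = 0.
source_points = [[0.0, 0.0], [1.0, 0.5], [-0.5, 2.0]]

# Each sample is pushed along the flow to its position at time t1 = 1.
transported = [euler_transport(translation_field, p) for p in source_points]
```

With this constant field, each transported point equals its source point plus `shift`, up to floating-point rounding. A learned field would generally depend on both \(t\) and \(x\), and more integration steps or a higher-order solver may be needed for accuracy.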

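Parameters:

hidden_dims

output_dims

condition_dims (default: None)

time_dims (default: None)

time_encoder (default: cyclical_time_encoder)

act_fn (default: silu)

dropout_rate (default: 0.0)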

Methods

act_fn()

SiLU (aka swish) activation function.

apply(variables, *args[, rngs, method, ...])

Applies a module method to variables and returns output and modified variables.

clone(*[, parent, _deep_clone, _reset_names])

Creates a clone of this Module, with optionally updated arguments.

copy(*[, parent, name])

Creates a copy of this Module, with optionally updated arguments.

create_train_state(rng, optimizer, input_dim)

Create the training state.

get_variable(col, name[, default])

Retrieves the value of a Variable.

has_rng(name)

Returns true if a PRNGSequence with the given name exists.

has_variable(col, name)

Checks if a variable of given collection and name exists in this Module.

init(rngs, *args[, method, mutable, ...])

Initializes a module method with variables and returns modified variables.

init_with_output(rngs, *args[, method, ...])

Initializes a module method with variables and returns output and modified variables.

is_initializing()

Returns True if running under self.init(...) or nn.init(...)().

is_mutable_collection(col)

Returns true if the collection col is mutable.

lazy_init(rngs, *args[, method, mutable])

Initializes a module without computing on an actual input.

make_rng([name])

Returns a new RNG key from a given RNG sequence for this Module.

param(name, init_fn, *init_args[, unbox])

Declares and returns a parameter in this Module.

perturb(name, value[, collection])

Add a zero-value variable ('perturbation') to the intermediate value.

put_variable(col, name, value)

Updates the value of the given variable if it is mutable, or raises an error otherwise.

setup()

Initializes a Module lazily (similar to a lazy __init__).

sow(col, name, value[, reduce_fn, init_fn])

Stores a value in a collection.

time_encoder([n_freqs])

Encode time \(t\) into a cyclical representation.
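As a rough illustration of what a cyclical time representation can look like, the sketch below maps a scalar time to cosine and sine features. The function name `cyclical_encoding`, the choice of frequencies \(2\pi k\) for \(k = 0, \dots, n\_freqs - 1\), and the small `n_freqs` default are assumptions made for this example; the library's encoder may use different frequencies, defaults, or output layout.

```python
import math


def cyclical_encoding(t, n_freqs=4):
    """Illustrative cyclical encoding of a scalar time t.

    Assumes frequencies 2*pi*k for k = 0, ..., n_freqs - 1 and returns
    the cosine features followed by the sine features (2 * n_freqs values).
    """
    freqs = [2.0 * math.pi * k for k in range(n_freqs)]
    cos_part = [math.cos(f * t) for f in freqs]
    sin_part = [math.sin(f * t) for f in freqs]
    return cos_part + sin_part


features_start = cyclical_encoding(0.0)
features_mid = cyclical_encoding(0.5)
features_end = cyclical_encoding(1.0)
```

Under these assumptions the encoding has period 1, so `features_start` and `features_end` agree up to floating-point error, while `features_mid` differs. The output length is `2 * n_freqs`.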

unbind()

Returns an unbound copy of a Module and its variables.

variable(col, name[, init_fn, unbox])

Declares and returns a variable in this Module.

Attributes

condition_dims

dropout_rate

name

parent

path

scope

time_dims

variables

Returns the variables in this module.

hidden_dims

output_dims