ott.neural.networks.potentials.PotentialMLP.init_with_output

PotentialMLP.init_with_output(rngs, *args, method=None, mutable=DenyList(deny='intermediates'), capture_intermediates=False, **kwargs)

Initializes the module's variables by running a module method, and returns both the method's output and the modified variables.

Parameters:
  • rngs (Array | dict[str, Array]) – The rngs for the variable collections.

  • *args – Positional arguments passed to the init function.

  • method (Union[Callable[..., Any], str, None]) – An optional method to apply. If not provided, the __call__ method is applied. A string can also be provided to specify a method by name.

  • mutable (Union[bool, str, Collection[str], DenyList]) – Specifies which collections are treated as mutable. It can be a bool (True: all collections are mutable; False: none are), a str (the name of a single mutable collection), or a list of mutable collection names. By default, all collections except “intermediates” are mutable.

  • capture_intermediates (bool | Callable[[Module, str], bool]) – If True, captures intermediate return values of all Modules inside the “intermediates” collection. By default only the return values of all __call__ methods are stored. A function can be passed to change the filter behavior. The filter function takes the Module instance and method name and returns a bool indicating whether the output of that method invocation should be stored.

  • **kwargs – Keyword arguments passed to the init function.

Return type:

tuple[Any, FrozenDict[str, Mapping[str, Any]] | dict[str, Any]]

Returns:

(output, vars), where output is the return value of the applied method and vars is a dict of the modified collections.