# Source code for ott.neural.networks.layers.time_encoder
# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax.numpy as jnp
__all__ = ["cyclical_time_encoder"]
# [docs]
def cyclical_time_encoder(t: jnp.ndarray, n_freqs: int = 128) -> jnp.ndarray:
    r"""Map time :math:`t` to cosine/sine features at integer frequencies.

    The phases are :math:`\hat{t} = [2\pi k t]_{k=0}^{n_f - 1}`, i.e. the
    frequency multipliers come from ``jnp.arange(n_freqs)`` and therefore
    start at :math:`k = 0`. The result is :math:`\cos(\hat{t})` followed by
    :math:`\sin(\hat{t})`, joined along the last axis.

    Note that the :math:`k = 0` phase is always zero, so the first cosine
    column is constantly ``1`` and the first sine column is constantly ``0``.

    Args:
      t: Time of shape ``[n, 1]``; it is broadcast against the ``[n_freqs]``
        frequency vector.
      n_freqs: Number of frequencies :math:`n_f` in the encoding.

    Returns:
      Encoded time of shape ``[n, 2 * n_freqs]``: the first ``n_freqs``
      columns hold the cosines, the remaining ``n_freqs`` hold the sines.
    """
    # Angular frequencies 0, 2*pi, 4*pi, ..., 2*pi*(n_freqs - 1).
    angular_freqs = 2 * jnp.arange(n_freqs) * jnp.pi
    # Broadcasting [n_freqs] against [n, 1] yields one phase row per time.
    phases = angular_freqs * t
    cos_features = jnp.cos(phases)
    sin_features = jnp.sin(phases)
    return jnp.concatenate([cos_features, sin_features], axis=-1)