Source code for ott.neural.datasets

# Copyright OTT-JAX
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import dataclasses
from typing import Any, Dict, Optional, Sequence

import numpy as np

# Public names of this module.
__all__ = ["OTData", "OTDataset"]

# Type of a single dataset item: a mapping from a field name to an array.
Item_t = Dict[str, np.ndarray]


@dataclasses.dataclass(repr=False, frozen=True)
class OTData:
    """Distribution data for (conditional) optimal transport problems.

    Every attribute is optional; attributes left as ``None`` are skipped
    when an item is fetched.

    Args:
      lin: Linear term of the samples.
      quad: Quadratic term of the samples.
      condition: Condition corresponding to the data distribution.
    """
    lin: Optional[np.ndarray] = None
    quad: Optional[np.ndarray] = None
    condition: Optional[np.ndarray] = None

    def __getitem__(self, ix: int) -> Item_t:
        """Index every attribute that is set with ``ix``.

        Args:
          ix: Index applied to each attribute that is not ``None``.

        Returns:
          Dictionary keyed by attribute name holding the indexed values.
        """
        item = {}
        for name, value in vars(self).items():
            if value is None:
                continue
            item[name] = value[ix]
        return item

    def __len__(self) -> int:
        """Return the length of ``lin`` if set, else of ``quad``, else 0."""
        # ``lin`` takes precedence over ``quad``; ``condition`` is not consulted.
        for term in (self.lin, self.quad):
            if term is not None:
                return len(term)
        return 0
class OTDataset:
    """Dataset for optimal transport problems.

    Each item is a dictionary that merges one source sample with one target
    sample. Keys of the source sample are prefixed with ``"src_"`` and keys of
    the target sample with ``"tgt_"``.

    Target conditions are used as dictionary keys, so they must be hashable.

    Args:
      src_data: Samples from the source distribution.
      tgt_data: Samples from the target distribution.
      src_conditions: Conditions for the source data. Defaults to ``None``
        for every source sample.
      tgt_conditions: Conditions for the target data. Defaults to ``None``
        for every target sample.
      is_aligned: Whether the samples from the source and the target data
        are paired. If yes, the source and the target conditions must match.
      seed: Random seed used to match source and target when not aligned.
    """
    SRC_PREFIX = "src"
    TGT_PREFIX = "tgt"

    def __init__(
        self,
        src_data: OTData,
        tgt_data: OTData,
        src_conditions: Optional[Sequence[Any]] = None,
        tgt_conditions: Optional[Sequence[Any]] = None,
        is_aligned: bool = False,
        seed: Optional[int] = None,
    ):
        self.src_data = src_data
        self.tgt_data = tgt_data

        if src_conditions is None:
            src_conditions = [None] * len(src_data)
        self.src_conditions = list(src_conditions)
        if tgt_conditions is None:
            tgt_conditions = [None] * len(tgt_data)
        self.tgt_conditions = list(tgt_conditions)

        # Map each target condition to the indices of the target samples that
        # carry it. Iterate over the materialised list rather than the raw
        # argument: a one-shot iterable has already been consumed by
        # ``list(...)`` above and would leave this mapping empty.
        self._tgt_cond_to_ix = collections.defaultdict(list)
        for ix, cond in enumerate(self.tgt_conditions):
            self._tgt_cond_to_ix[cond].append(ix)

        self.is_aligned = is_aligned
        self._rng = np.random.default_rng(seed)

        self._verify_integrity()

    def _verify_integrity(self) -> None:
        """Check that data and conditions are mutually consistent.

        Raises:
          AssertionError: If the number of conditions differs from the number
            of samples, or, when aligned, if source and target differ in
            length or conditions, or, when not aligned, if the sets of source
            and target conditions differ.
        """
        assert len(self.src_data) == len(self.src_conditions), \
            "Number of source samples and source conditions differ."
        assert len(self.tgt_data) == len(self.tgt_conditions), \
            "Number of target samples and target conditions differ."

        if self.is_aligned:
            assert len(self.src_data) == len(self.tgt_data), \
                "Aligned source and target data differ in length."
            assert self.src_conditions == self.tgt_conditions, \
                "Aligned source and target conditions differ."
        else:
            # Every source condition needs at least one target sample to draw
            # from (and vice versa), so the two sets must coincide.
            sym_diff = set(self.src_conditions
                          ).symmetric_difference(self.tgt_conditions)
            assert not sym_diff, sym_diff

    def _sample_from_target(self, src_ix: int) -> Item_t:
        """Draw a random target sample sharing the source sample's condition.

        Args:
          src_ix: Index of the source sample.

        Returns:
          The target item at an index chosen by the dataset's random
          generator among the target samples with the same condition.
        """
        src_cond = self.src_conditions[src_ix]
        tgt_ixs = self._tgt_cond_to_ix[src_cond]
        ix = self._rng.choice(tgt_ixs)
        return self.tgt_data[ix]

    def __getitem__(self, ix: int) -> Item_t:
        """Return the ``ix``-th source sample together with a target sample.

        Args:
          ix: Index of the source sample.

        Returns:
          Dictionary with the source entries under ``"src_<key>"`` and the
          target entries under ``"tgt_<key>"``. The target sample is the one
          at the same index when the dataset is aligned, otherwise it is
          sampled among the targets with a matching condition.
        """
        src = self.src_data[ix]
        src = {f"{self.SRC_PREFIX}_{k}": v for k, v in src.items()}

        if self.is_aligned:
            tgt = self.tgt_data[ix]
        else:
            tgt = self._sample_from_target(ix)
        tgt = {f"{self.TGT_PREFIX}_{k}": v for k, v in tgt.items()}

        return {**src, **tgt}

    def __len__(self) -> int:
        """Return the number of source samples."""
        return len(self.src_data)