diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:03:11 +0200 |
commit | 30e3ae483c846418b04ed48f014a4af2cf9a0771 (patch) | |
tree | 08a309e14416e68ad351be8a3d48bf50efd80d6b /text_recognizer/data/base_dataset.py | |
parent | ad3f404d36a9add32992698dd083d368f3b96812 (diff) |
Update transforms in datamodule/set
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 22 |
1 files changed, 18 insertions, 4 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index e08130d..b9567c7 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -6,6 +6,8 @@ import torch from torch import Tensor from torch.utils.data import Dataset +from text_recognizer.data.transforms.load_transform import load_transform_from_file + @attr.s class BaseDataset(Dataset): @@ -21,8 +23,8 @@ class BaseDataset(Dataset): data: Union[Sequence, Tensor] = attr.ib() targets: Union[Sequence, Tensor] = attr.ib() - transform: Optional[Callable] = attr.ib(default=None) - target_transform: Optional[Callable] = attr.ib(default=None) + transform: Union[Optional[Callable], str] = attr.ib(default=None) + target_transform: Union[Optional[Callable], str] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: """Pre init constructor.""" @@ -32,19 +34,31 @@ class BaseDataset(Dataset): """Post init constructor.""" if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") + self.transform = self._load_transform(self.transform) + self.target_transform = self._load_transform(self.target_transform) + + @staticmethod + def _load_transform( + transform: Union[Optional[Callable], str] + ) -> Optional[Callable]: + if isinstance(transform, str): + return load_transform_from_file(transform) + return transform def __len__(self) -> int: """Return the length of the dataset.""" return len(self.data) - def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: + def __getitem__( + self, index: int + ) -> Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]: """Return a datum and its target, after processing by transforms. Args: index (int): Index of a datum in the dataset. Returns: - Tuple[Tensor, Tensor]: Datum and target pair. + Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]: Datum and target pair. """ datum, target = self.data[index], self.targets[index] |