diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 01:44:49 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 01:44:49 +0200 |
commit | 9b8e14d89f0ef2508ed11f994f73af624155fe1d (patch) | |
tree | 10d9c138f4449646c8b5c0f95003ba86b831d04d /text_recognizer/data/base_dataset.py | |
parent | 63376b1c2da81c23ad5239f908b640cd42a514c7 (diff) |
Update data modules
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 16 |
1 files changed, 2 insertions, 14 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 4ceb818..b840bc8 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -5,8 +5,6 @@ import torch from torch import Tensor from torch.utils.data import Dataset -from text_recognizer.data.transforms.load_transform import load_transform_from_file - class BaseDataset(Dataset): r"""Base Dataset class that processes data and targets through optional transfroms. @@ -23,8 +21,8 @@ class BaseDataset(Dataset): self, data: Union[Sequence, Tensor], targets: Union[Sequence, Tensor], - transform: Union[Optional[Callable], str], - target_transform: Union[Optional[Callable], str], + transform: Callable, + target_transform: Callable, ) -> None: super().__init__() @@ -34,16 +32,6 @@ class BaseDataset(Dataset): self.target_transform = target_transform if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") - self.transform = self._load_transform(self.transform) - self.target_transform = self._load_transform(self.target_transform) - - @staticmethod - def _load_transform( - transform: Union[Optional[Callable], str] - ) -> Optional[Callable]: - if isinstance(transform, str): - return load_transform_from_file(transform) - return transform def __len__(self) -> int: """Return the length of the dataset.""" |