diff options
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 16 |
1 files changed, 2 insertions, 14 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 4ceb818..b840bc8 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -5,8 +5,6 @@ import torch from torch import Tensor from torch.utils.data import Dataset -from text_recognizer.data.transforms.load_transform import load_transform_from_file - class BaseDataset(Dataset): r"""Base Dataset class that processes data and targets through optional transfroms. @@ -23,8 +21,8 @@ class BaseDataset(Dataset): self, data: Union[Sequence, Tensor], targets: Union[Sequence, Tensor], - transform: Union[Optional[Callable], str], - target_transform: Union[Optional[Callable], str], + transform: Callable, + target_transform: Callable, ) -> None: super().__init__() @@ -34,16 +32,6 @@ class BaseDataset(Dataset): self.target_transform = target_transform if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") - self.transform = self._load_transform(self.transform) - self.target_transform = self._load_transform(self.target_transform) - - @staticmethod - def _load_transform( - transform: Union[Optional[Callable], str] - ) -> Optional[Callable]: - if isinstance(transform, str): - return load_transform_from_file(transform) - return transform def __len__(self) -> int: """Return the length of the dataset.""" |