summaryrefslogtreecommitdiff
path: root/text_recognizer/data/base_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r--text_recognizer/data/base_dataset.py16
1 files changed, 2 insertions, 14 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 4ceb818..b840bc8 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -5,8 +5,6 @@ import torch
from torch import Tensor
from torch.utils.data import Dataset
-from text_recognizer.data.transforms.load_transform import load_transform_from_file
-
class BaseDataset(Dataset):
r"""Base Dataset class that processes data and targets through optional transfroms.
@@ -23,8 +21,8 @@ class BaseDataset(Dataset):
self,
data: Union[Sequence, Tensor],
targets: Union[Sequence, Tensor],
- transform: Union[Optional[Callable], str],
- target_transform: Union[Optional[Callable], str],
+ transform: Callable,
+ target_transform: Callable,
) -> None:
super().__init__()
@@ -34,16 +32,6 @@ class BaseDataset(Dataset):
self.target_transform = target_transform
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
- self.transform = self._load_transform(self.transform)
- self.target_transform = self._load_transform(self.target_transform)
-
- @staticmethod
- def _load_transform(
- transform: Union[Optional[Callable], str]
- ) -> Optional[Callable]:
- if isinstance(transform, str):
- return load_transform_from_file(transform)
- return transform
def __len__(self) -> int:
"""Return the length of the dataset."""