summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/base_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- text_recognizer/data/base_dataset.py | 22
1 files changed, 18 insertions, 4 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index e08130d..b9567c7 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -6,6 +6,8 @@ import torch
from torch import Tensor
from torch.utils.data import Dataset
+from text_recognizer.data.transforms.load_transform import load_transform_from_file
+
@attr.s
class BaseDataset(Dataset):
@@ -21,8 +23,8 @@ class BaseDataset(Dataset):
data: Union[Sequence, Tensor] = attr.ib()
targets: Union[Sequence, Tensor] = attr.ib()
- transform: Optional[Callable] = attr.ib(default=None)
- target_transform: Optional[Callable] = attr.ib(default=None)
+ transform: Union[Optional[Callable], str] = attr.ib(default=None)
+ target_transform: Union[Optional[Callable], str] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
"""Pre init constructor."""
@@ -32,19 +34,31 @@ class BaseDataset(Dataset):
"""Post init constructor."""
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
+ self.transform = self._load_transform(self.transform)
+ self.target_transform = self._load_transform(self.target_transform)
+
+    @staticmethod
+    def _load_transform(
+        transform: Union[Optional[Callable], str]
+    ) -> Optional[Callable]:
+        """Resolve a transform that may have been given as a string.
+
+        Args:
+            transform: A callable, ``None``, or a string. A string is passed
+                to ``load_transform_from_file``.
+
+        Returns:
+            The result of ``load_transform_from_file`` for a string argument;
+            any other value is handed back unchanged.
+        """
+        # Anything that is not a string needs no loading.
+        if not isinstance(transform, str):
+            return transform
+        return load_transform_from_file(transform)
def __len__(self) -> int:
    """Return how many entries ``self.data`` holds."""
    num_items = len(self.data)
    return num_items
- def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
+ def __getitem__(
+ self, index: int
+ ) -> Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]:
"""Return a datum and its target, after processing by transforms.
Args:
index (int): Index of a datum in the dataset.
Returns:
- Tuple[Tensor, Tensor]: Datum and target pair.
+ Tuple[Union[Tensor, Tuple[Tensor, Tensor]], Tensor]: Datum and target pair.
"""
datum, target = self.data[index], self.targets[index]