diff options
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index c26f1c9..8640d92 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,5 +1,5 @@ """Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import attr import torch @@ -22,14 +22,13 @@ class BaseDataset(Dataset): data: Union[Sequence, Tensor] = attr.ib() targets: Union[Sequence, Tensor] = attr.ib() - transform: Callable = attr.ib() - target_transform: Callable = attr.ib() + transform: Optional[Callable] = attr.ib(default=None) + target_transform: Optional[Callable] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: super().__init__() def __attrs_post_init__(self) -> None: - # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") @@ -37,14 +36,14 @@ class BaseDataset(Dataset): """Return the length of the dataset.""" return len(self.data) - def __getitem__(self, index: int) -> Tuple[Any, Any]: + def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """Return a datum and its target, after processing by transforms. Args: index (int): Index of a datum in the dataset. Returns: - Tuple[Any, Any]: Datum and target pair. + Tuple[Tensor, Tensor]: Datum and target pair. """ datum, target = self.data[index], self.targets[index] |