diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-02 21:13:48 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-02 21:13:48 +0200 |
commit | 75801019981492eedf9280cb352eea3d8e99b65f (patch) | |
tree | 6521cc4134459e42591b2375f70acd348741474e /text_recognizer/data/base_dataset.py | |
parent | e5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff) |
Fix log import, fix mapping in datamodules, fix nn modules can be hashed
Diffstat (limited to 'text_recognizer/data/base_dataset.py')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 11 |
1 files changed, 5 insertions, 6 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index c26f1c9..8640d92 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,5 +1,5 @@ """Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import attr import torch @@ -22,14 +22,13 @@ class BaseDataset(Dataset): data: Union[Sequence, Tensor] = attr.ib() targets: Union[Sequence, Tensor] = attr.ib() - transform: Callable = attr.ib() - target_transform: Callable = attr.ib() + transform: Optional[Callable] = attr.ib(default=None) + target_transform: Optional[Callable] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: super().__init__() def __attrs_post_init__(self) -> None: - # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") @@ -37,14 +36,14 @@ class BaseDataset(Dataset): """Return the length of the dataset.""" return len(self.data) - def __getitem__(self, index: int) -> Tuple[Any, Any]: + def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """Return a datum and its target, after processing by transforms. Args: index (int): Index of a datum in the dataset. Returns: - Tuple[Any, Any]: Datum and target pair. + Tuple[Tensor, Tensor]: Datum and target pair. """ datum, target = self.data[index], self.targets[index] |