From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- text_recognizer/data/base_dataset.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) (limited to 'text_recognizer/data/base_dataset.py') diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index c26f1c9..8640d92 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,5 +1,5 @@ """Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import attr import torch @@ -22,14 +22,13 @@ class BaseDataset(Dataset): data: Union[Sequence, Tensor] = attr.ib() targets: Union[Sequence, Tensor] = attr.ib() - transform: Callable = attr.ib() - target_transform: Callable = attr.ib() + transform: Optional[Callable] = attr.ib(default=None) + target_transform: Optional[Callable] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: super().__init__() def __attrs_post_init__(self) -> None: - # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") @@ -37,14 +36,14 @@ class BaseDataset(Dataset): """Return the length of the dataset.""" return len(self.data) - def __getitem__(self, index: int) -> Tuple[Any, Any]: + def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """Return a datum and its target, after processing by transforms. Args: index (int): Index of a datum in the dataset. Returns: - Tuple[Any, Any]: Datum and target pair. + Tuple[Tensor, Tensor]: Datum and target pair. """ datum, target = self.data[index], self.targets[index] -- cgit v1.2.3-70-g09d2