diff options
Diffstat (limited to 'src/text_recognizer/datasets/iam_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/iam_lines_dataset.py | 68 |
1 files changed, 21 insertions, 47 deletions
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py index 477f500..4a74b2b 100644 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -5,11 +5,15 @@ import h5py from loguru import logger import torch from torch import Tensor -from torch.utils.data import Dataset from torchvision.transforms import ToTensor -from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper -from text_recognizer.datasets.util import compute_sha256, download_url +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.util import ( + compute_sha256, + DATA_DIRNAME, + download_url, + EmnistMapper, +) PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines" @@ -29,47 +33,26 @@ class IamLinesDataset(Dataset): transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, ) -> None: - self.train = train - self.split = "train" if self.train else "test" - self._mapper = EmnistMapper() - self.num_classes = self.mapper.num_classes - - # Set transforms. - self.transform = transform - if self.transform is None: - self.transform = ToTensor() - - self.target_transform = target_transform - if self.target_transform is None: - self.target_transform = torch.tensor - - self.subsample_fraction = subsample_fraction - self.data = None - self.targets = None - - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping + super().__init__( + train=train, + subsample_fraction=subsample_fraction, + transform=transform, + target_transform=target_transform, + ) @property def input_shape(self) -> Tuple: """Input shape of the data.""" - return self.data.shape[1:] + return self.data.shape[1:] if self.data is not None else None @property def output_shape(self) -> Tuple: """Output shape of the data.""" - return self.targets.shape[1:] + (self.num_classes,) - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) + return ( + self.targets.shape[1:] + (self.num_classes,) + if self.targets is not None + else None + ) def load_or_generate_data(self) -> None: """Load or generate dataset data.""" @@ -78,19 +61,10 @@ class IamLinesDataset(Dataset): logger.info("Downloading IAM lines...") download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME) with h5py.File(PROCESSED_DATA_FILENAME, "r") as f: - self.data = f[f"x_{self.split}"][:] - self.targets = f[f"y_{self.split}"][:] + self._data = f[f"x_{self.split}"][:] + self._targets = f[f"y_{self.split}"][:] self._subsample() - def _subsample(self) -> None: - """Only a fraction of the data will be loaded.""" - if self.subsample_fraction is None: - return - - num_samples = int(self.data.shape[0] * self.subsample_fraction) - self.data = self.data[:num_samples] - self.targets = self.targets[:num_samples] - def __repr__(self) -> str: """Print info about the dataset.""" return ( |