From 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 9 Sep 2020 23:31:31 +0200 Subject: Created an abstract Dataset class for common methods. --- .../datasets/emnist_lines_dataset.py | 56 +++++----------------- 1 file changed, 13 insertions(+), 43 deletions(-) (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py') diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 656131a..8fa77cd 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -9,17 +9,16 @@ from loguru import logger import numpy as np import torch from torch import Tensor -from torch.utils.data import Dataset from torchvision.transforms import ToTensor -from text_recognizer.datasets import ( +from text_recognizer.datasets.dataset import Dataset +from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose +from text_recognizer.datasets.sentence_generator import SentenceGenerator +from text_recognizer.datasets.util import ( DATA_DIRNAME, - EmnistDataset, EmnistMapper, ESSENTIALS_FILENAME, ) -from text_recognizer.datasets.sentence_generator import SentenceGenerator -from text_recognizer.datasets.util import Transpose from text_recognizer.networks import sliding_window DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" @@ -52,18 +51,11 @@ class EmnistLinesDataset(Dataset): seed (int): Seed number. Defaults to 4711. """ - self.train = train - - self.transform = transform - if self.transform is None: - self.transform = ToTensor() - - self.target_transform = target_transform - if self.target_transform is None: - self.target_transform = torch.tensor + super().__init__( + train=train, transform=transform, target_transform=target_transform, + ) # Extract dataset information. - self._mapper = EmnistMapper() self._input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes @@ -75,24 +67,12 @@ class EmnistLinesDataset(Dataset): self.input_shape[0], self.input_shape[1] * self.max_length, ) - self.output_shape = (self.max_length, self.num_classes) + self._output_shape = (self.max_length, self.num_classes) self.seed = seed # Placeholders for the dataset. - self.data = None - self.target = None - - # Load dataset. - self._load_or_generate_data() - - @property - def input_shape(self) -> Tuple: - """Input shape of the data.""" - return self._input_shape - - def __len__(self) -> int: - """Returns the length of the dataset.""" - return len(self.data) + self._data = None + self._target = None def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: """Fetches data, target pair of the dataset for a given and index or indices. @@ -131,16 +111,6 @@ class EmnistLinesDataset(Dataset): f"Tagets: {self.targets.shape}\n" ) - @property - def mapper(self) -> EmnistMapper: - """Returns the EmnistMapper.""" - return self._mapper - - @property - def mapping(self) -> Dict: - """Return EMNIST mapping from index to character.""" - return self._mapper.mapping - @property def data_filename(self) -> Path: """Path to the h5 file.""" @@ -151,7 +121,7 @@ class EmnistLinesDataset(Dataset): filename = "test_" + filename return DATA_DIRNAME / filename - def _load_or_generate_data(self) -> None: + def load_or_generate_data(self) -> None: """Loads the dataset, if it does not exist a new dataset is generated before loading it.""" np.random.seed(self.seed) @@ -163,8 +133,8 @@ class EmnistLinesDataset(Dataset): """Loads the dataset from the h5 file.""" logger.debug("EmnistLinesDataset loading data from HDF5...") with h5py.File(self.data_filename, "r") as f: - self.data = f["data"][:] - self.targets = f["targets"][:] + self._data = f["data"][:] + self._targets = f["targets"][:] def _generate_data(self) -> str: """Generates a dataset with the Brown corpus and Emnist characters.""" -- cgit v1.2.3-70-g09d2