| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-09 23:31:31 +0200 |
| --- | --- | --- |
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-09 23:31:31 +0200 |
| commit | 2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (patch) | |
| tree | 1c0e0898cb8b66faff9e5d410aa1f82d13542f68 /src/text_recognizer/datasets/emnist_lines_dataset.py | |
| parent | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (diff) | |
Created an abstract Dataset class for common methods.
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 56 |
1 file changed, 13 insertions, 43 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 656131a..8fa77cd 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -9,17 +9,16 @@ from loguru import logger
 import numpy as np
 import torch
 from torch import Tensor
-from torch.utils.data import Dataset
 from torchvision.transforms import ToTensor
 
-from text_recognizer.datasets import (
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+from text_recognizer.datasets.util import (
     DATA_DIRNAME,
-    EmnistDataset,
     EmnistMapper,
     ESSENTIALS_FILENAME,
 )
-from text_recognizer.datasets.sentence_generator import SentenceGenerator
-from text_recognizer.datasets.util import Transpose
 from text_recognizer.networks import sliding_window
 
 DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -52,18 +51,11 @@ class EmnistLinesDataset(Dataset):
             seed (int): Seed number. Defaults to 4711.
 
         """
-        self.train = train
-
-        self.transform = transform
-        if self.transform is None:
-            self.transform = ToTensor()
-
-        self.target_transform = target_transform
-        if self.target_transform is None:
-            self.target_transform = torch.tensor
+        super().__init__(
+            train=train, transform=transform, target_transform=target_transform,
+        )
 
         # Extract dataset information.
-        self._mapper = EmnistMapper()
         self._input_shape = self._mapper.input_shape
         self.num_classes = self._mapper.num_classes
 
@@ -75,24 +67,12 @@ class EmnistLinesDataset(Dataset):
             self.input_shape[0],
             self.input_shape[1] * self.max_length,
         )
-        self.output_shape = (self.max_length, self.num_classes)
+        self._output_shape = (self.max_length, self.num_classes)
         self.seed = seed
 
         # Placeholders for the dataset.
-        self.data = None
-        self.target = None
-
-        # Load dataset.
-        self._load_or_generate_data()
-
-    @property
-    def input_shape(self) -> Tuple:
-        """Input shape of the data."""
-        return self._input_shape
-
-    def __len__(self) -> int:
-        """Returns the length of the dataset."""
-        return len(self.data)
+        self._data = None
+        self._target = None
 
     def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
         """Fetches data, target pair of the dataset for a given and index or indices.
@@ -132,16 +112,6 @@ class EmnistLinesDataset(Dataset):
         )
 
     @property
-    def mapper(self) -> EmnistMapper:
-        """Returns the EmnistMapper."""
-        return self._mapper
-
-    @property
-    def mapping(self) -> Dict:
-        """Return EMNIST mapping from index to character."""
-        return self._mapper.mapping
-
-    @property
     def data_filename(self) -> Path:
         """Path to the h5 file."""
         filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -151,7 +121,7 @@ class EmnistLinesDataset(Dataset):
             filename = "test_" + filename
         return DATA_DIRNAME / filename
 
-    def _load_or_generate_data(self) -> None:
+    def load_or_generate_data(self) -> None:
         """Loads the dataset, if it does not exist a new dataset is generated before loading it."""
         np.random.seed(self.seed)
 
@@ -163,8 +133,8 @@ class EmnistLinesDataset(Dataset):
         """Loads the dataset from the h5 file."""
         logger.debug("EmnistLinesDataset loading data from HDF5...")
         with h5py.File(self.data_filename, "r") as f:
-            self.data = f["data"][:]
-            self.targets = f["targets"][:]
+            self._data = f["data"][:]
+            self._targets = f["targets"][:]
 
     def _generate_data(self) -> str:
         """Generates a dataset with the Brown corpus and Emnist characters."""
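The new base class, `text_recognizer/datasets/dataset.py`, is not shown in this diff. The sketch below is a hypothetical reconstruction of what it would need to provide for the refactored subclass above to work: the `super().__init__(train=..., transform=..., target_transform=...)` call, the `_mapper`/`_input_shape` attributes, the `_data`/`_targets` placeholders, and the `input_shape`/`mapper`/`mapping`/`__len__` members removed from `EmnistLinesDataset`. The actual implementation in the repository may differ.

```python
"""Hypothetical sketch of the shared Dataset base class (not the actual file)."""
from typing import Callable, Dict, Optional, Tuple

import torch
from torch import Tensor
from torch.utils.data import Dataset as TorchDataset
from torchvision.transforms import ToTensor

from text_recognizer.datasets.util import EmnistMapper


class Dataset(TorchDataset):
    """Common functionality for the EMNIST-based datasets (inferred sketch)."""

    def __init__(
        self,
        train: bool = False,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        self.train = train
        # Default transforms mirror what the subclass previously set up itself.
        self.transform = transform if transform is not None else ToTensor()
        self.target_transform = (
            target_transform if target_transform is not None else torch.tensor
        )

        # Shared EMNIST character mapping and derived input shape.
        self._mapper = EmnistMapper()
        self._input_shape = self._mapper.input_shape

        # Placeholders populated by load_or_generate_data() in subclasses.
        self._data = None
        self._targets = None

    @property
    def input_shape(self) -> Tuple:
        """Input shape of the data."""
        return self._input_shape

    @property
    def mapper(self) -> EmnistMapper:
        """Returns the EmnistMapper."""
        return self._mapper

    @property
    def mapping(self) -> Dict:
        """Returns the mapping from index to character."""
        return self._mapper.mapping

    @property
    def data(self) -> Tensor:
        """The dataset data."""
        return self._data

    @property
    def targets(self) -> Tensor:
        """The dataset targets."""
        return self._targets

    def load_or_generate_data(self) -> None:
        """Subclasses load (or generate) their data here."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Returns the number of samples in the dataset."""
        return 0 if self._data is None else len(self._data)
```

The net effect visible in the hunks above: the per-dataset boilerplate for default transforms, the `EmnistMapper`, and the data/target placeholders moves into one shared base class, and the private `_load_or_generate_data` that was called from `__init__` becomes a public `load_or_generate_data` that callers invoke explicitly.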