diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-08 23:14:23 +0200 |
commit | e1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch) | |
tree | 70b482f890c9ad2be104f0bff8f2172e8411a2be /src/text_recognizer/datasets/emnist_lines_dataset.py | |
parent | fe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff) |
IAM datasets implemented.
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 38 |
1 files changed, 22 insertions, 16 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index b0617f5..656131a 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -9,8 +9,8 @@ from loguru import logger import numpy as np import torch from torch import Tensor -from torch.utils.data import DataLoader, Dataset -from torchvision.transforms import Compose, Normalize, ToTensor +from torch.utils.data import Dataset +from torchvision.transforms import ToTensor from text_recognizer.datasets import ( DATA_DIRNAME, @@ -20,6 +20,7 @@ from text_recognizer.datasets import ( ) from text_recognizer.datasets.sentence_generator import SentenceGenerator from text_recognizer.datasets.util import Transpose +from text_recognizer.networks import sliding_window DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" @@ -55,7 +56,7 @@ class EmnistLinesDataset(Dataset): self.transform = transform if self.transform is None: - self.transform = Compose([ToTensor()]) + self.transform = ToTensor() self.target_transform = target_transform if self.target_transform is None: @@ -63,14 +64,14 @@ class EmnistLinesDataset(Dataset): # Extract dataset information. self._mapper = EmnistMapper() - self.input_shape = self._mapper.input_shape + self._input_shape = self._mapper.input_shape self.num_classes = self._mapper.num_classes self.max_length = max_length self.min_overlap = min_overlap self.max_overlap = max_overlap self.num_samples = num_samples - self.input_shape = ( + self._input_shape = ( self.input_shape[0], self.input_shape[1] * self.max_length, ) @@ -84,6 +85,11 @@ class EmnistLinesDataset(Dataset): # Load dataset. self._load_or_generate_data() + @property + def input_shape(self) -> Tuple: + """Input shape of the data.""" + return self._input_shape + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -112,11 +118,6 @@ class EmnistLinesDataset(Dataset): return data, targets - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "EmnistLinesDataset" - def __repr__(self) -> str: """Returns information about the dataset.""" return ( @@ -136,13 +137,18 @@ class EmnistLinesDataset(Dataset): return self._mapper @property + def mapping(self) -> Dict: + """Return EMNIST mapping from index to character.""" + return self._mapper.mapping + + @property def data_filename(self) -> Path: """Path to the h5 file.""" filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" if self.train: filename = "train_" + filename else: - filename = "val_" + filename + filename = "test_" + filename return DATA_DIRNAME / filename def _load_or_generate_data(self) -> None: @@ -184,7 +190,7 @@ class EmnistLinesDataset(Dataset): ) targets = convert_strings_to_categorical_labels( - targets, self.emnist.inverse_mapping + targets, emnist.inverse_mapping ) f.create_dataset("data", data=data, dtype="u1", compression="lzf") @@ -322,13 +328,13 @@ def create_datasets( min_overlap: float = 0, max_overlap: float = 0.33, num_train: int = 10000, - num_val: int = 1000, + num_test: int = 1000, ) -> None: """Creates a training an validation dataset of Emnist lines.""" emnist_train = EmnistDataset(train=True, sample_to_balance=True) - emnist_val = EmnistDataset(train=False, sample_to_balance=True) - datasets = [emnist_train, emnist_val] - num_samples = [num_train, num_val] + emnist_test = EmnistDataset(train=False, sample_to_balance=True) + datasets = [emnist_train, emnist_test] + num_samples = [num_train, num_test] for num, train, dataset in zip(num_samples, [True, False], datasets): emnist_lines = EmnistLinesDataset( train=train, |