diff options
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py       |  5 |
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 35 |
-rw-r--r-- | src/text_recognizer/datasets/transforms.py           | 27 |
3 files changed, 48 insertions, 19 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index a8901d6..9884fdf 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -22,6 +22,7 @@ class EmnistDataset(Dataset): def __init__( self, + pad_token: str = None, train: bool = False, sample_to_balance: bool = False, subsample_fraction: float = None, @@ -32,6 +33,7 @@ class EmnistDataset(Dataset): """Loads the dataset and the mappings. Args: + pad_token (str): The pad token symbol. Defaults to _. train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False. subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None. @@ -45,6 +47,7 @@ class EmnistDataset(Dataset): subsample_fraction=subsample_fraction, transform=transform, target_transform=target_transform, + pad_token=pad_token, ) self.sample_to_balance = sample_to_balance @@ -53,6 +56,8 @@ class EmnistDataset(Dataset): if transform is None: self.transform = Compose([Transpose(), ToTensor()]) + self.target_transform = None + self.seed = seed def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]: diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 6091da8..6871492 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -4,6 +4,7 @@ from collections import defaultdict from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Union +import click import h5py from loguru import logger import numpy as np @@ -58,13 +59,15 @@ class EmnistLinesDataset(Dataset): eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. 
""" + self.pad_token = "_" if pad_token is None else pad_token + super().__init__( train=train, transform=transform, target_transform=target_transform, subsample_fraction=subsample_fraction, init_token=init_token, - pad_token=pad_token, + pad_token=self.pad_token, eos_token=eos_token, ) @@ -127,11 +130,7 @@ class EmnistLinesDataset(Dataset): @property def data_filename(self) -> Path: """Path to the h5 file.""" - filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" - if self.train: - filename = "train_" + filename - else: - filename = "test_" + filename + filename = "train.pt" if self.train else "test.pt" return DATA_DIRNAME / filename def load_or_generate_data(self) -> None: @@ -147,8 +146,8 @@ class EmnistLinesDataset(Dataset): """Loads the dataset from the h5 file.""" logger.debug("EmnistLinesDataset loading data from HDF5...") with h5py.File(self.data_filename, "r") as f: - self._data = f["data"][:] - self._targets = f["targets"][:] + self._data = f["data"][()] + self._targets = f["targets"][()] def _generate_data(self) -> str: """Generates a dataset with the Brown corpus and Emnist characters.""" @@ -157,7 +156,9 @@ class EmnistLinesDataset(Dataset): sentence_generator = SentenceGenerator(self.max_length) # Load emnist dataset. - emnist = EmnistDataset(train=self.train, sample_to_balance=True) + emnist = EmnistDataset( + train=self.train, sample_to_balance=True, pad_token=self.pad_token + ) emnist.load_or_generate_data() samples_by_character = get_samples_by_character( @@ -308,6 +309,18 @@ def convert_strings_to_categorical_labels( return np.array([[mapping[c] for c in label] for label in labels]) +@click.command() +@click.option( + "--max_length", type=int, default=34, help="Number of characters in a sentence." +) +@click.option( + "--min_overlap", type=float, default=0.0, help="Min overlap between characters." +) +@click.option( + "--max_overlap", type=float, default=0.33, help="Max overlap between characters." 
+) +@click.option("--num_train", type=int, default=10_000, help="Number of train examples.") +@click.option("--num_test", type=int, default=1_000, help="Number of test examples.") def create_datasets( max_length: int = 34, min_overlap: float = 0, @@ -326,3 +339,7 @@ def create_datasets( num_samples=num, ) emnist_lines.load_or_generate_data() + + +if __name__ == "__main__": + create_datasets() diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index c058972..8deac7f 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -3,7 +3,7 @@ import numpy as np from PIL import Image import torch from torch import Tensor -from torchvision.transforms import Compose, ToTensor +from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor from text_recognizer.datasets.util import EmnistMapper @@ -19,28 +19,35 @@ class Transpose: class AddTokens: """Adds start of sequence and end of sequence tokens to target tensor.""" - def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None: + def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None: self.init_token = init_token self.pad_token = pad_token self.eos_token = eos_token - self.emnist_mapper = EmnistMapper( - init_token=self.init_token, - pad_token=self.pad_token, - eos_token=self.eos_token, - ) + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, + ) self.pad_value = self.emnist_mapper(self.pad_token) - self.sos_value = self.emnist_mapper(self.init_token) self.eos_value = self.emnist_mapper(self.eos_token) def __call__(self, target: Tensor) -> Tensor: """Adds a sos token to the begining and a eos token to the end of a target sequence.""" dtype, device = target.dtype, 
target.device - sos = torch.tensor([self.sos_value], dtype=dtype, device=device) # Find the where padding starts. pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item() target[pad_index] = self.eos_value - target = torch.cat([sos, target], dim=0) + if self.init_token is not None: + self.sos_value = self.emnist_mapper(self.init_token) + sos = torch.tensor([self.sos_value], dtype=dtype, device=device) + target = torch.cat([sos, target], dim=0) + return target |