Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--  src/text_recognizer/datasets/__init__.py               |  4
-rw-r--r--  src/text_recognizer/datasets/dataset.py                | 22
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py         |  6
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py   | 51
-rw-r--r--  src/text_recognizer/datasets/iam_lines_dataset.py      |  6
-rw-r--r--  src/text_recognizer/datasets/iam_paragraphs_dataset.py |  7
-rw-r--r--  src/text_recognizer/datasets/transforms.py             | 40
-rw-r--r--  src/text_recognizer/datasets/util.py                   | 31
8 files changed, 133 insertions, 34 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index a3af9b1..d8372e3 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,5 +1,5 @@
 """Dataset modules."""
-from .emnist_dataset import EmnistDataset, Transpose
+from .emnist_dataset import EmnistDataset
 from .emnist_lines_dataset import (
     construct_image_from_string,
     EmnistLinesDataset,
@@ -8,6 +8,7 @@ from .emnist_lines_dataset import (
 from .iam_dataset import IamDataset
 from .iam_lines_dataset import IamLinesDataset
 from .iam_paragraphs_dataset import IamParagraphsDataset
+from .transforms import AddTokens, Transpose
 from .util import (
     _download_raw_dataset,
     compute_sha256,
@@ -19,6 +20,7 @@ from .util import (
 
 __all__ = [
     "_download_raw_dataset",
+    "AddTokens",
     "compute_sha256",
     "construct_image_from_string",
     "DATA_DIRNAME",
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 05520e5..2de7f09 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -18,6 +18,9 @@ class Dataset(data.Dataset):
         subsample_fraction: float = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        init_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
     ) -> None:
         """Initialization of Dataset class.
@@ -26,12 +29,14 @@ class Dataset(data.Dataset):
             subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
             transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
             target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+            init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+            pad_token (Optional[str]): String representing the pad token. Defaults to None.
+            eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
 
         Raises:
             ValueError: If subsample_fraction is not None and outside the range (0, 1).
 
         """
-
         self.train = train
         self.split = "train" if self.train else "test"
@@ -40,19 +45,18 @@ class Dataset(data.Dataset):
             raise ValueError("The subsample fraction must be in (0, 1).")
         self.subsample_fraction = subsample_fraction
 
-        self._mapper = EmnistMapper()
+        self._mapper = EmnistMapper(
+            init_token=init_token, eos_token=eos_token, pad_token=pad_token
+        )
         self._input_shape = self._mapper.input_shape
         self._output_shape = self._mapper._num_classes
         self.num_classes = self.mapper.num_classes
 
         # Set transforms.
-        self.transform = transform
-        if self.transform is None:
-            self.transform = ToTensor()
-
-        self.target_transform = target_transform
-        if self.target_transform is None:
-            self.target_transform = torch.tensor
+        self.transform = transform if transform is not None else ToTensor()
+        self.target_transform = (
+            target_transform if target_transform is not None else torch.tensor
+        )
 
         self._data = None
         self._targets = None
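Note: with this change the token plumbing lives in the Dataset base class, which forwards init_token, pad_token, and eos_token to EmnistMapper. A minimal sketch of the resulting behaviour (not part of the commit; the "<sos>"/"<eos>" strings are illustrative, and it assumes the EMNIST essentials file has already been generated):

    from text_recognizer.datasets.util import EmnistMapper

    mapper = EmnistMapper(pad_token="_", init_token="<sos>", eos_token="<eos>")
    print(mapper.num_classes)  # EMNIST classes + IAM symbols + the three special tokens
    print(mapper("<eos>"))     # index the mapper assigned to the end-of-sequence token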
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index d01dcee..9884fdf 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -22,6 +22,7 @@ class EmnistDataset(Dataset):
 
     def __init__(
         self,
+        pad_token: str = None,
         train: bool = False,
         sample_to_balance: bool = False,
         subsample_fraction: float = None,
@@ -32,6 +33,7 @@
         """Loads the dataset and the mappings.
 
         Args:
+            pad_token (str): The pad token symbol. Defaults to _.
             train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
             sample_to_balance (bool): Resamples the dataset to make it balanced. Defaults to False.
             subsample_fraction (float): Description of parameter `subsample_fraction`. Defaults to None.
@@ -45,6 +47,7 @@ class EmnistDataset(Dataset):
             subsample_fraction=subsample_fraction,
             transform=transform,
             target_transform=target_transform,
+            pad_token=pad_token,
         )
         self.sample_to_balance = sample_to_balance
 
@@ -53,8 +56,7 @@
         if transform is None:
             self.transform = Compose([Transpose(), ToTensor()])
 
-        # The EMNIST dataset is already casted to tensors.
-        self.target_transform = target_transform
+        self.target_transform = None
 
         self.seed = seed
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 6268a01..6871492 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -4,6 +4,7 @@ from collections import defaultdict
 from pathlib import Path
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
+import click
 import h5py
 from loguru import logger
 import numpy as np
@@ -37,6 +38,9 @@ class EmnistLinesDataset(Dataset):
         max_overlap: float = 0.33,
         num_samples: int = 10000,
         seed: int = 4711,
+        init_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
     ) -> None:
         """Set attributes and loads the dataset.
 
@@ -50,13 +54,21 @@
             max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
             num_samples (int): Number of samples to generate. Defaults to 10000.
             seed (int): Seed number. Defaults to 4711.
+            init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+            pad_token (Optional[str]): String representing the pad token. Defaults to None.
+            eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
 
         """
+        self.pad_token = "_" if pad_token is None else pad_token
+
         super().__init__(
             train=train,
             transform=transform,
             target_transform=target_transform,
             subsample_fraction=subsample_fraction,
+            init_token=init_token,
+            pad_token=self.pad_token,
+            eos_token=eos_token,
         )
 
         # Extract dataset information.
@@ -118,11 +130,7 @@
     @property
     def data_filename(self) -> Path:
         """Path to the h5 file."""
-        filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
-        if self.train:
-            filename = "train_" + filename
-        else:
-            filename = "test_" + filename
+        filename = "train.pt" if self.train else "test.pt"
         return DATA_DIRNAME / filename
 
     def load_or_generate_data(self) -> None:
@@ -138,8 +146,8 @@
         """Loads the dataset from the h5 file."""
         logger.debug("EmnistLinesDataset loading data from HDF5...")
         with h5py.File(self.data_filename, "r") as f:
-            self._data = f["data"][:]
-            self._targets = f["targets"][:]
+            self._data = f["data"][()]
+            self._targets = f["targets"][()]
 
     def _generate_data(self) -> str:
         """Generates a dataset with the Brown corpus and Emnist characters."""
@@ -148,7 +156,10 @@
         sentence_generator = SentenceGenerator(self.max_length)
 
         # Load emnist dataset.
-        emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+        emnist = EmnistDataset(
+            train=self.train, sample_to_balance=True, pad_token=self.pad_token
+        )
+        emnist.load_or_generate_data()
 
         samples_by_character = get_samples_by_character(
             emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
@@ -298,6 +309,18 @@ def convert_strings_to_categorical_labels(
     return np.array([[mapping[c] for c in label] for label in labels])
 
 
+@click.command()
+@click.option(
+    "--max_length", type=int, default=34, help="Number of characters in a sentence."
+)
+@click.option(
+    "--min_overlap", type=float, default=0.0, help="Min overlap between characters."
+)
+@click.option(
+    "--max_overlap", type=float, default=0.33, help="Max overlap between characters."
+)
+@click.option("--num_train", type=int, default=10_000, help="Number of train examples.")
+@click.option("--num_test", type=int, default=1_000, help="Number of test examples.")
 def create_datasets(
     max_length: int = 34,
     min_overlap: float = 0,
@@ -306,17 +329,17 @@ def create_datasets(
     num_test: int = 1000,
 ) -> None:
     """Creates a training an validation dataset of Emnist lines."""
-    emnist_train = EmnistDataset(train=True, sample_to_balance=True)
-    emnist_test = EmnistDataset(train=False, sample_to_balance=True)
-    datasets = [emnist_train, emnist_test]
     num_samples = [num_train, num_test]
-    for num, train, dataset in zip(num_samples, [True, False], datasets):
+    for num, train in zip(num_samples, [True, False]):
         emnist_lines = EmnistLinesDataset(
             train=train,
-            emnist=dataset,
             max_length=max_length,
             min_overlap=min_overlap,
             max_overlap=max_overlap,
             num_samples=num,
         )
-        emnist_lines._load_or_generate_data()
+        emnist_lines.load_or_generate_data()
+
+
+if __name__ == "__main__":
+    create_datasets()
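Note: the new click entry point makes the module runnable as a script. A hedged usage sketch (assumes the package is importable and EMNIST has been downloaded; the "<sos>"/"<eos>" strings are illustrative):

    from text_recognizer.datasets import EmnistLinesDataset

    lines = EmnistLinesDataset(
        train=True,
        max_length=34,
        min_overlap=0.0,
        max_overlap=0.33,
        num_samples=10_000,
        init_token="<sos>",
        eos_token="<eos>",
    )
    # pad_token defaults to "_"; loads train.pt from DATA_DIRNAME or generates it.
    lines.load_or_generate_data()

From the shell, something like `python -m text_recognizer.datasets.emnist_lines_dataset --num_train 10000 --num_test 1000` should do the same, given the package is on the path.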
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 4a74b2b..fdd2fe6 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -32,12 +32,18 @@ class IamLinesDataset(Dataset):
         subsample_fraction: float = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        init_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
     ) -> None:
         super().__init__(
             train=train,
             subsample_fraction=subsample_fraction,
             transform=transform,
             target_transform=target_transform,
+            init_token=init_token,
+            pad_token=pad_token,
+            eos_token=eos_token,
         )
 
     @property
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index 4b34bd1..c1e8fe2 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -266,11 +266,16 @@ def _load_iam_paragraphs() -> None:
 @click.option(
     "--subsample_fraction",
     type=float,
-    default=0.0,
+    default=None,
     help="The subsampling factor of the dataset.",
 )
 def main(subsample_fraction: float) -> None:
     """Load dataset and print info."""
+    logger.info("Creating train set...")
+    dataset = IamParagraphsDataset(train=True, subsample_fraction=subsample_fraction)
+    dataset.load_or_generate_data()
+    print(dataset)
+
+    logger.info("Creating test set...")
     dataset = IamParagraphsDataset(subsample_fraction=subsample_fraction)
     dataset.load_or_generate_data()
     print(dataset)
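Note: changing the click default from 0.0 to None matters because the Dataset base class validates the fraction and a 0.0 default would trip the check. A sketch of the guard, reconstructed from the docstring and the ValueError shown in dataset.py (the exact condition is an assumption):

    subsample_fraction = 0.0  # the old CLI default
    if subsample_fraction is not None and not 0.0 < subsample_fraction < 1.0:
        raise ValueError("The subsample fraction must be in (0, 1).")
    # With default=None the check is skipped and the full dataset is used.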
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 17231a8..8deac7f 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,6 +3,9 @@ import numpy as np
 from PIL import Image
 import torch
 from torch import Tensor
+from torchvision.transforms import Compose, Resize, ToPILImage, ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
 
 
 class Transpose:
@@ -11,3 +14,40 @@ class Transpose:
     def __call__(self, image: Image) -> np.ndarray:
         """Swaps axis."""
         return np.array(image).swapaxes(0, 1)
+
+
+class AddTokens:
+    """Adds start of sequence and end of sequence tokens to target tensor."""
+
+    def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
+        self.init_token = init_token
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+        if self.init_token is not None:
+            self.emnist_mapper = EmnistMapper(
+                init_token=self.init_token,
+                pad_token=self.pad_token,
+                eos_token=self.eos_token,
+            )
+        else:
+            self.emnist_mapper = EmnistMapper(
+                pad_token=self.pad_token, eos_token=self.eos_token,
+            )
+        self.pad_value = self.emnist_mapper(self.pad_token)
+        self.eos_value = self.emnist_mapper(self.eos_token)
+
+    def __call__(self, target: Tensor) -> Tensor:
+        """Adds a sos token to the begining and a eos token to the end of a target sequence."""
+        dtype, device = target.dtype, target.device
+
+        # Find the where padding starts.
+        pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+        target[pad_index] = self.eos_value
+
+        if self.init_token is not None:
+            self.sos_value = self.emnist_mapper(self.init_token)
+            sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+            target = torch.cat([sos, target], dim=0)
+
+        return target
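Note: AddTokens overwrites the first padding position with the end-of-sequence index and, when init_token is given, prepends the start-of-sequence index. A small sketch (token strings and class indices are illustrative):

    import torch
    from text_recognizer.datasets.transforms import AddTokens

    add_tokens = AddTokens(pad_token="_", eos_token="<eos>", init_token="<sos>")
    # Class indices for a three-character label followed by two pad slots.
    target = torch.tensor([10, 11, 12, add_tokens.pad_value, add_tokens.pad_value])
    target = add_tokens(target)
    # -> [sos, 10, 11, 12, eos, pad]: the first pad slot now holds the eos
    #    index, and the sequence grows by one element when sos is prepended.

The transform assumes at least one pad value is present in the target; on a fully unpadded sequence the torch.nonzero lookup would raise an IndexError.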
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 73968a1..d2df8b5 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -4,6 +4,7 @@ import importlib
 import json
 import os
 from pathlib import Path
+import string
 from typing import Callable, Dict, List, Optional, Type, Union
 from urllib.request import urlopen, urlretrieve
 
@@ -26,7 +27,7 @@ def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
     mapping = [(i, str(label)) for i, label in enumerate(labels)]
     essentials = {
         "mapping": mapping,
-        "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
+        "input_shape": tuple(np.array(emnsit_dataset[0][0]).shape[:]),
     }
     logger.info("Saving emnist essentials...")
     with open(ESSENTIALS_FILENAME, "w") as f:
@@ -43,11 +44,21 @@ def download_emnist() -> None:
 class EmnistMapper:
     """Mapper between network output to Emnist character."""
 
-    def __init__(self) -> None:
+    def __init__(
+        self,
+        pad_token: str,
+        init_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
+    ) -> None:
         """Loads the emnist essentials file with the mapping and input shape."""
+        self.init_token = init_token
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+
         self.essentials = self._load_emnist_essentials()
         # Load dataset infromation.
-        self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+        self._mapping = dict(self.essentials["mapping"])
+        self._augment_emnist_mapping()
         self._inverse_mapping = {v: k for k, v in self.mapping.items()}
         self._num_classes = len(self.mapping)
         self._input_shape = self.essentials["input_shape"]
@@ -103,7 +114,7 @@ class EmnistMapper:
             essentials = json.load(f)
         return essentials
 
-    def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+    def _augment_emnist_mapping(self) -> None:
         """Augment the mapping with extra symbols."""
         # Extra symbols in IAM dataset
         extra_symbols = [
@@ -127,14 +138,20 @@ class EmnistMapper:
         ]
 
         # padding symbol, and acts as blank symbol as well.
-        extra_symbols.append("_")
+        extra_symbols.append(self.pad_token)
+
+        if self.init_token is not None:
+            extra_symbols.append(self.init_token)
+
+        if self.eos_token is not None:
+            extra_symbols.append(self.eos_token)
 
-        max_key = max(mapping.keys())
+        max_key = max(self.mapping.keys())
         extra_mapping = {}
         for i, symbol in enumerate(extra_symbols):
             extra_mapping[max_key + 1 + i] = symbol
 
-        return {**mapping, **extra_mapping}
+        self._mapping = {**self.mapping, **extra_mapping}
 
 
 def compute_sha256(filename: Union[Path, str]) -> str:
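Note: after this refactor, _augment_emnist_mapping mutates self._mapping in place instead of returning a new dict, and the special tokens are appended after the highest existing key, so pre-existing EMNIST indices stay stable. A brief sketch (the "<eos>" string is illustrative; assumes the essentials file exists):

    from text_recognizer.datasets.util import EmnistMapper

    mapper = EmnistMapper(pad_token="_", eos_token="<eos>")
    pad_value = mapper("_")    # resolve a token string to its class index
    print(mapper.num_classes)  # now includes the IAM symbols plus "_" and "<eos>"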