Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--  src/text_recognizer/datasets/__init__.py             |  4
-rw-r--r--  src/text_recognizer/datasets/dataset.py              | 22
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py       |  3
-rw-r--r--  src/text_recognizer/datasets/emnist_lines_dataset.py |  9
-rw-r--r--  src/text_recognizer/datasets/iam_lines_dataset.py    |  6
-rw-r--r--  src/text_recognizer/datasets/transforms.py           | 33
-rw-r--r--  src/text_recognizer/datasets/util.py                 | 29
7 files changed, 87 insertions, 19 deletions
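
In brief, this commit threads optional start of sequence (init), pad, and end of sequence tokens from the dataset constructors down to EmnistMapper, which appends them to the EMNIST character mapping, and adds an AddTokens target transform that applies them to label sequences. As a worked sketch of the mapping augmentation implemented in util.py below; the token strings "<sos>" and "<eos>" are hypothetical stand-ins (only the formerly hardcoded pad symbol "_" appears in the diff):

    # Stand-in for the essentials mapping loaded by EmnistMapper; the real keys
    # are EMNIST class indices.
    mapping = {0: "0", 1: "1", 2: "2"}
    extra_symbols = [".", ","]                 # IAM punctuation (abbreviated here)
    extra_symbols.append("_")                  # pad token, doubles as blank symbol
    extra_symbols.extend(["<sos>", "<eos>"])   # appended only if init/eos tokens are given
    max_key = max(mapping.keys())
    mapping.update({max_key + 1 + i: s for i, s in enumerate(extra_symbols)})
    # Each extra symbol now has its own class index at the tail of the mapping.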
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index a3af9b1..d8372e3 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,5 +1,5 @@
 """Dataset modules."""
-from .emnist_dataset import EmnistDataset, Transpose
+from .emnist_dataset import EmnistDataset
 from .emnist_lines_dataset import (
     construct_image_from_string,
     EmnistLinesDataset,
@@ -8,6 +8,7 @@ from .emnist_lines_dataset import (
 from .iam_dataset import IamDataset
 from .iam_lines_dataset import IamLinesDataset
 from .iam_paragraphs_dataset import IamParagraphsDataset
+from .transforms import AddTokens, Transpose
 from .util import (
     _download_raw_dataset,
     compute_sha256,
@@ -19,6 +20,7 @@ from .util import (

 __all__ = [
     "_download_raw_dataset",
+    "AddTokens",
     "compute_sha256",
     "construct_image_from_string",
     "DATA_DIRNAME",
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 05520e5..2de7f09 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -18,6 +18,9 @@ class Dataset(data.Dataset):
         subsample_fraction: float = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        init_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
     ) -> None:
         """Initialization of Dataset class.

@@ -26,12 +29,14 @@
             subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
             transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
             target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+            init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+            pad_token (Optional[str]): String representing the pad token. Defaults to None.
+            eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.

         Raises:
             ValueError: If subsample_fraction is not None and outside the range (0, 1).

         """
-
         self.train = train
         self.split = "train" if self.train else "test"

@@ -40,19 +45,18 @@ class Dataset(data.Dataset):
             raise ValueError("The subsample fraction must be in (0, 1).")
         self.subsample_fraction = subsample_fraction

-        self._mapper = EmnistMapper()
+        self._mapper = EmnistMapper(
+            init_token=init_token, eos_token=eos_token, pad_token=pad_token
+        )
         self._input_shape = self._mapper.input_shape
         self._output_shape = self._mapper._num_classes
         self.num_classes = self.mapper.num_classes

         # Set transforms.
-        self.transform = transform
-        if self.transform is None:
-            self.transform = ToTensor()
-
-        self.target_transform = target_transform
-        if self.target_transform is None:
-            self.target_transform = torch.tensor
+        self.transform = transform if transform is not None else ToTensor()
+        self.target_transform = (
+            target_transform if target_transform is not None else torch.tensor
+        )

         self._data = None
         self._targets = None
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index d01dcee..a8901d6 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -53,9 +53,6 @@ class EmnistDataset(Dataset):
         if transform is None:
             self.transform = Compose([Transpose(), ToTensor()])

-        # The EMNIST dataset is already cast to tensors.
-        self.target_transform = target_transform
-
         self.seed = seed

     def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index beb5343..6091da8 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -37,6 +37,9 @@ class EmnistLinesDataset(Dataset):
         max_overlap: float = 0.33,
         num_samples: int = 10000,
         seed: int = 4711,
+        init_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
     ) -> None:
         """Set attributes and loads the dataset.

@@ -50,6 +53,9 @@
             max_overlap (float): The maximum overlap between concatenated images. Defaults to 0.33.
             num_samples (int): Number of samples to generate. Defaults to 10000.
             seed (int): Seed number. Defaults to 4711.
+            init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
+            pad_token (Optional[str]): String representing the pad token. Defaults to None.
+            eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.

         """
         super().__init__(
@@ -57,6 +63,9 @@
             transform=transform,
             target_transform=target_transform,
             subsample_fraction=subsample_fraction,
+            init_token=init_token,
+            pad_token=pad_token,
+            eos_token=eos_token,
         )

         # Extract dataset information.
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 4a74b2b..fdd2fe6 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -32,12 +32,18 @@ class IamLinesDataset(Dataset):
         subsample_fraction: float = None,
         transform: Optional[Callable] = None,
         target_transform: Optional[Callable] = None,
+        init_token: Optional[str] = None,
+        pad_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
     ) -> None:
         super().__init__(
             train=train,
             subsample_fraction=subsample_fraction,
             transform=transform,
             target_transform=target_transform,
+            init_token=init_token,
+            pad_token=pad_token,
+            eos_token=eos_token,
         )

     @property
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 17231a8..c058972 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -3,6 +3,9 @@ import numpy as np
 from PIL import Image
 import torch
 from torch import Tensor
+from torchvision.transforms import Compose, ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper


 class Transpose:
@@ -11,3 +14,33 @@ class Transpose:
     def __call__(self, image: Image) -> np.ndarray:
         """Swaps axis."""
         return np.array(image).swapaxes(0, 1)
+
+
+class AddTokens:
+    """Adds start of sequence and end of sequence tokens to target tensor."""
+
+    def __init__(self, init_token: str, pad_token: str, eos_token: str,) -> None:
+        self.init_token = init_token
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+        self.emnist_mapper = EmnistMapper(
+            init_token=self.init_token,
+            pad_token=self.pad_token,
+            eos_token=self.eos_token,
+        )
+        self.pad_value = self.emnist_mapper(self.pad_token)
+        self.sos_value = self.emnist_mapper(self.init_token)
+        self.eos_value = self.emnist_mapper(self.eos_token)
+
+    def __call__(self, target: Tensor) -> Tensor:
+        """Adds an sos token to the beginning and an eos token to the end of a target sequence."""
+        dtype, device = target.dtype, target.device
+        sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
+
+        # Find where the padding starts.
+        pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
+
+        target[pad_index] = self.eos_value
+
+        target = torch.cat([sos, target], dim=0)
+        return target
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 125f05a..d2df8b5 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -4,6 +4,7 @@ import importlib
 import json
 import os
 from pathlib import Path
+import string
 from typing import Callable, Dict, List, Optional, Type, Union
 from urllib.request import urlopen, urlretrieve

@@ -43,11 +44,21 @@ def download_emnist() -> None:
 class EmnistMapper:
     """Mapper between network output and Emnist character."""

-    def __init__(self) -> None:
+    def __init__(
+        self,
+        pad_token: str,
+        init_token: Optional[str] = None,
+        eos_token: Optional[str] = None,
+    ) -> None:
         """Loads the emnist essentials file with the mapping and input shape."""
+        self.init_token = init_token
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+
         self.essentials = self._load_emnist_essentials()
         # Load dataset information.
-        self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+        self._mapping = dict(self.essentials["mapping"])
+        self._augment_emnist_mapping()
         self._inverse_mapping = {v: k for k, v in self.mapping.items()}
         self._num_classes = len(self.mapping)
         self._input_shape = self.essentials["input_shape"]
@@ -103,7 +114,7 @@ class EmnistMapper:
         essentials = json.load(f)
         return essentials

-    def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+    def _augment_emnist_mapping(self) -> None:
         """Augment the mapping with extra symbols."""
         # Extra symbols in IAM dataset
         extra_symbols = [
@@ -127,14 +138,20 @@ class EmnistMapper:
         ]

         # padding symbol, and acts as blank symbol as well.
-        extra_symbols.append("_")
+        extra_symbols.append(self.pad_token)
+
+        if self.init_token is not None:
+            extra_symbols.append(self.init_token)
+
+        if self.eos_token is not None:
+            extra_symbols.append(self.eos_token)

-        max_key = max(mapping.keys())
+        max_key = max(self.mapping.keys())
         extra_mapping = {}
         for i, symbol in enumerate(extra_symbols):
             extra_mapping[max_key + 1 + i] = symbol

-        return {**mapping, **extra_mapping}
+        self._mapping = {**self.mapping, **extra_mapping}


 def compute_sha256(filename: Union[Path, str]) -> str:
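
For orientation, a minimal sketch of how the new pieces might be wired together downstream; it is not part of the commit. The token strings "<sos>" and "<eos>" are hypothetical, and running it assumes the EMNIST data and the essentials file that EmnistMapper loads are available locally:

    import torch
    from torchvision.transforms import Compose

    from text_recognizer.datasets import AddTokens, EmnistLinesDataset

    INIT, PAD, EOS = "<sos>", "_", "<eos>"  # hypothetical token strings

    # AddTokens expects a Tensor target containing at least one pad position,
    # so the raw target is cast to a tensor first.
    target_transform = Compose(
        [torch.tensor, AddTokens(init_token=INIT, pad_token=PAD, eos_token=EOS)]
    )

    dataset = EmnistLinesDataset(
        train=True,
        init_token=INIT,  # forwarded to EmnistMapper so each token gets a class index
        pad_token=PAD,
        eos_token=EOS,
        target_transform=target_transform,
    )

    _, target = dataset[0]
    # target is now [sos, c1, ..., ck, eos, pad, ...] as class indices.

Note that AddTokens writes the eos value over the first pad position rather than growing the sequence, so the transformed target is exactly one element longer than the padded input (the prepended sos).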