diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-15 17:40:44 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-15 17:40:44 +0100 |
commit | 75909723fa2b1f6245d5c5422e4f2e88b8a26052 (patch) | |
tree | e60c37d05c724db011d75adf9313d93839d193ac /src/text_recognizer/datasets | |
parent | cad676fc423efeafde65f03e4815248f2d357011 (diff) |
Able to generate support files for lines datasets.
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/dataset.py | 39 | ||||
-rw-r--r-- | src/text_recognizer/datasets/iam_lines_dataset.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 19 |
3 files changed, 43 insertions, 17 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py index 2de7f09..95063bc 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/src/text_recognizer/datasets/dataset.py @@ -1,11 +1,12 @@ """Abstract dataset class.""" -from typing import Callable, Dict, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union import torch from torch import Tensor from torch.utils import data from torchvision.transforms import ToTensor +import text_recognizer.datasets.transforms as transforms from text_recognizer.datasets.util import EmnistMapper @@ -16,8 +17,8 @@ class Dataset(data.Dataset): self, train: bool, subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, + transform: Optional[List[Dict]] = None, + target_transform: Optional[List[Dict]] = None, init_token: Optional[str] = None, pad_token: Optional[str] = None, eos_token: Optional[str] = None, @@ -27,8 +28,8 @@ class Dataset(data.Dataset): Args: train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False. subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None. - transform (Optional[Callable]): Transform(s) for input data. Defaults to None. - target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None. + transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None. + target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None. init_token (Optional[str]): String representing the start of sequence token. Defaults to None. pad_token (Optional[str]): String representing the pad token. Defaults to None. eos_token (Optional[str]): String representing the end of sequence token. Defaults to None. @@ -53,14 +54,34 @@ class Dataset(data.Dataset): self.num_classes = self.mapper.num_classes # Set transforms. - self.transform = transform if transform is not None else ToTensor() - self.target_transform = ( - target_transform if target_transform is not None else torch.tensor - ) + self.transform = self._configure_transform(transform) + self.target_transform = self._configure_target_transform(target_transform) self._data = None self._targets = None + def _configure_transform(self, transform: List[Dict]) -> transforms.Compose: + transform_list = [] + if transform is not None: + for t in transform: + t_type = t["type"] + t_args = t["args"] or {} + transform_list.append(getattr(transforms, t_type)(**t_args)) + else: + transform_list.append(ToTensor()) + return transforms.Compose(transform_list) + + def _configure_target_transform( + self, target_transform: List[Dict] + ) -> transforms.Compose: + target_transform_list = [torch.tensor] + if target_transform is not None: + for t in target_transform: + t_type = t["type"] + t_args = t["args"] or {} + target_transform_list.append(getattr(transforms, t_type)(**t_args)) + return transforms.Compose(target_transform_list) + @property def data(self) -> Tensor: """The input data.""" diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py index fdd2fe6..5ae142c 100644 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ b/src/text_recognizer/datasets/iam_lines_dataset.py @@ -36,6 +36,8 @@ class IamLinesDataset(Dataset): pad_token: Optional[str] = None, eos_token: Optional[str] = None, ) -> None: + self.pad_token = "_" if pad_token is None else pad_token + super().__init__( train=train, subsample_fraction=subsample_fraction, diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index d2df8b5..bf5e772 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -12,7 +12,8 @@ import cv2 from loguru import logger import numpy as np from PIL import Image -from torch.utils.data import DataLoader, Dataset +import torch +from torch import Tensor from torchvision.datasets import EMNIST from tqdm import tqdm @@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data" ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json" -def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -56,21 +57,21 @@ class EmnistMapper: self.eos_token = eos_token self.essentials = self._load_emnist_essentials() - # Load dataset infromation. + # Load dataset information. self._mapping = dict(self.essentials["mapping"]) self._augment_emnist_mapping() self._inverse_mapping = {v: k for k, v in self.mapping.items()} self._num_classes = len(self.mapping) self._input_shape = self.essentials["input_shape"] - def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]: """Maps the token to emnist character or character index. If the token is an integer (index), the method will return the Emnist character corresponding to that index. If the token is a str (Emnist character), the method will return the corresponding index for that character. Args: - token (Union[str, int, np.uint8]): Eihter a string or index (integer). + token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer). Returns: Union[str, int]: The mapping result. @@ -79,9 +80,11 @@ class EmnistMapper: KeyError: If the index or string does not exist in the mapping. """ - if (isinstance(token, np.uint8) or isinstance(token, int)) and int( - token - ) in self.mapping: + if ( + (isinstance(token, np.uint8) or isinstance(token, int)) + or torch.is_tensor(token) + and int(token) in self.mapping + ): return self.mapping[int(token)] elif isinstance(token, str) and token in self._inverse_mapping: return self._inverse_mapping[token] |