Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--   src/text_recognizer/datasets/dataset.py              |  4
-rw-r--r--   src/text_recognizer/datasets/emnist_lines_dataset.py |  3
-rw-r--r--   src/text_recognizer/datasets/iam_lines_dataset.py    |  2
-rw-r--r--   src/text_recognizer/datasets/transforms.py           |  9
-rw-r--r--   src/text_recognizer/datasets/util.py                 | 15
5 files changed, 27 insertions(+), 6 deletions(-)
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 95063bc..e794605 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -22,6 +22,7 @@ class Dataset(data.Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Initialization of Dataset class.
 
@@ -33,6 +34,7 @@ class Dataset(data.Dataset):
             init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
             pad_token (Optional[str]): String representing the pad token. Defaults to None.
             eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+            lower (bool): Only use lower case letters. Defaults to False.
 
         Raises:
             ValueError: If subsample_fraction is not None and outside the range (0, 1).
@@ -47,7 +49,7 @@ class Dataset(data.Dataset):
         self.subsample_fraction = subsample_fraction
 
         self._mapper = EmnistMapper(
-            init_token=init_token, eos_token=eos_token, pad_token=pad_token
+            init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
         )
         self._input_shape = self._mapper.input_shape
         self._output_shape = self._mapper._num_classes
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index eddf341..1992446 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Set attributes and loads the dataset.
 
@@ -60,6 +61,7 @@ class EmnistLinesDataset(Dataset):
             init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
             pad_token (Optional[str]): String representing the pad token. Defaults to None.
             eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+            lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.
 
         """
         self.pad_token = "_" if pad_token is None else pad_token
@@ -72,6 +74,7 @@
             init_token=init_token,
             pad_token=self.pad_token,
             eos_token=eos_token,
+            lower=lower,
         )
 
         # Extract dataset information.
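
The new `lower` flag is forwarded from the dataset constructors down to `EmnistMapper`. Below is a minimal usage sketch, not part of the diff: it assumes `src/` is on the Python path, that the constructor arguments not shown in this diff keep their defaults, and the token values are only illustrative.

    # Sketch only: constructing EmnistLinesDataset with the new flag.
    from text_recognizer.datasets.emnist_lines_dataset import EmnistLinesDataset

    dataset = EmnistLinesDataset(
        init_token="<sos>",  # hypothetical start-of-sequence token
        eos_token="<eos>",   # hypothetical end-of-sequence token
        pad_token="_",       # same value the class uses when pad_token is None
        lower=True,          # forwarded to EmnistMapper: targets restricted to digits + a-z
    )
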
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 5ae142c..1cb84bd 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -35,6 +35,7 @@ class IamLinesDataset(Dataset):
         init_token: Optional[str] = None,
         pad_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         self.pad_token = "_" if pad_token is None else pad_token
 
@@ -46,6 +47,7 @@
             init_token=init_token,
             pad_token=pad_token,
             eos_token=eos_token,
+            lower=lower,
         )
 
     @property
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 016ec80..8956b01 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -93,3 +93,12 @@ class Squeeze:
     def __call__(self, x: Tensor) -> Tensor:
         """Removes first dim."""
         return x.squeeze(0)
+
+
+class ToLower:
+    """Converts target to lower case."""
+
+    def __call__(self, target: Tensor) -> Tensor:
+        """Corrects index value in target tensor."""
+        device = target.device
+        return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index bf5e772..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,17 +1,14 @@
 """Util functions for datasets."""
 import hashlib
-import importlib
 import json
 import os
 from pathlib import Path
 import string
-from typing import Callable, Dict, List, Optional, Type, Union
-from urllib.request import urlopen, urlretrieve
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
 
-import cv2
 from loguru import logger
 import numpy as np
-from PIL import Image
 import torch
 from torch import Tensor
 from torchvision.datasets import EMNIST
@@ -50,11 +47,13 @@ class EmnistMapper:
         pad_token: str,
         init_token: Optional[str] = None,
         eos_token: Optional[str] = None,
+        lower: bool = False,
     ) -> None:
         """Loads the emnist essentials file with the mapping and input shape."""
         self.init_token = init_token
         self.pad_token = pad_token
         self.eos_token = eos_token
+        self.lower = lower
 
         self.essentials = self._load_emnist_essentials()
         # Load dataset information.
@@ -120,6 +119,12 @@ class EmnistMapper:
     def _augment_emnist_mapping(self) -> None:
         """Augment the mapping with extra symbols."""
         # Extra symbols in IAM dataset
+        if self.lower:
+            self._mapping = {
+                k: str(v)
+                for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+            }
+
         extra_symbols = [
             " ",
             "!",
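
For context, a sketch of how the new `ToLower` transform is expected to line up with the lower-case mapping built in `EmnistMapper._augment_emnist_mapping`. It is not part of the diff and assumes the standard EMNIST ByClass index layout (0-9 digits, 10-35 upper case, 36-61 lower case) and that `src/` is on the Python path.

    # Sketch only: ToLower shifts lower-case indices (36-61) down by 26 so they land in
    # 10-35, the range the lower-case EmnistMapper maps to 'a'-'z'. Upper-case indices
    # (10-35) are left unchanged and read as their lower-case counterparts.
    import torch
    from text_recognizer.datasets.transforms import ToLower

    # Targets for "Cat" under the assumed 62-class layout: 'C' -> 12, 'a' -> 36, 't' -> 55.
    target = torch.tensor([12, 36, 55])
    print(ToLower()(target))  # tensor([12, 10, 29]) -> "cat" under the lower-case mapping
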