summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/dataset.py4
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py3
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py2
-rw-r--r--src/text_recognizer/datasets/transforms.py9
-rw-r--r--src/text_recognizer/datasets/util.py15
5 files changed, 27 insertions, 6 deletions
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 95063bc..e794605 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -22,6 +22,7 @@ class Dataset(data.Dataset):
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Initialization of Dataset class.
@@ -33,6 +34,7 @@ class Dataset(data.Dataset):
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+        lower (bool): If True, use a lowercase-only character mapping. Defaults to False.
Raises:
ValueError: If subsample_fraction is not None and outside the range (0, 1).
@@ -47,7 +49,7 @@ class Dataset(data.Dataset):
self.subsample_fraction = subsample_fraction
self._mapper = EmnistMapper(
- init_token=init_token, eos_token=eos_token, pad_token=pad_token
+ init_token=init_token, eos_token=eos_token, pad_token=pad_token, lower=lower
)
self._input_shape = self._mapper.input_shape
self._output_shape = self._mapper._num_classes
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index eddf341..1992446 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -44,6 +44,7 @@ class EmnistLinesDataset(Dataset):
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Set attributes and loads the dataset.
@@ -60,6 +61,7 @@ class EmnistLinesDataset(Dataset):
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
+ lower (bool): If True, convert uppercase letters to lowercase. Otherwise, use both upper and lowercase.
"""
self.pad_token = "_" if pad_token is None else pad_token
@@ -72,6 +74,7 @@ class EmnistLinesDataset(Dataset):
init_token=init_token,
pad_token=self.pad_token,
eos_token=eos_token,
+ lower=lower,
)
# Extract dataset information.
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 5ae142c..1cb84bd 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -35,6 +35,7 @@ class IamLinesDataset(Dataset):
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
self.pad_token = "_" if pad_token is None else pad_token
@@ -46,6 +47,7 @@ class IamLinesDataset(Dataset):
init_token=init_token,
pad_token=pad_token,
eos_token=eos_token,
+ lower=lower,
)
@property
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 016ec80..8956b01 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -93,3 +93,12 @@ class Squeeze:
def __call__(self, x: Tensor) -> Tensor:
"""Removes first dim."""
return x.squeeze(0)
+
+
+class ToLower:
+    """Remaps target indices from the full mapping onto the lowercase-only mapping."""
+
+ def __call__(self, target: Tensor) -> Tensor:
+        """Shift indices above 35 down by 26 so lowercase letters reuse the uppercase slots."""
+ device = target.device
+ return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index bf5e772..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,17 +1,14 @@
"""Util functions for datasets."""
import hashlib
-import importlib
import json
import os
from pathlib import Path
import string
-from typing import Callable, Dict, List, Optional, Type, Union
-from urllib.request import urlopen, urlretrieve
+from typing import Dict, List, Optional, Union
+from urllib.request import urlretrieve
-import cv2
from loguru import logger
import numpy as np
-from PIL import Image
import torch
from torch import Tensor
from torchvision.datasets import EMNIST
@@ -50,11 +47,13 @@ class EmnistMapper:
pad_token: str,
init_token: Optional[str] = None,
eos_token: Optional[str] = None,
+ lower: bool = False,
) -> None:
"""Loads the emnist essentials file with the mapping and input shape."""
self.init_token = init_token
self.pad_token = pad_token
self.eos_token = eos_token
+ self.lower = lower
self.essentials = self._load_emnist_essentials()
# Load dataset information.
@@ -120,6 +119,12 @@ class EmnistMapper:
def _augment_emnist_mapping(self) -> None:
"""Augment the mapping with extra symbols."""
# Extra symbols in IAM dataset
+ if self.lower:
+ self._mapping = {
+ k: str(v)
+ for k, v in enumerate(list(range(10)) + list(string.ascii_lowercase))
+ }
+
extra_symbols = [
" ",
"!",