Diffstat (limited to 'src/text_recognizer/datasets')
 src/text_recognizer/datasets/dataset.py           | 39
 src/text_recognizer/datasets/iam_lines_dataset.py |  2
 src/text_recognizer/datasets/util.py              | 19
 3 files changed, 43 insertions(+), 17 deletions(-)
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
index 2de7f09..95063bc 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/src/text_recognizer/datasets/dataset.py
@@ -1,11 +1,12 @@
"""Abstract dataset class."""
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
from torch import Tensor
from torch.utils import data
from torchvision.transforms import ToTensor
+import text_recognizer.datasets.transforms as transforms
from text_recognizer.datasets.util import EmnistMapper
@@ -16,8 +17,8 @@ class Dataset(data.Dataset):
self,
train: bool,
subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
+ transform: Optional[List[Dict]] = None,
+ target_transform: Optional[List[Dict]] = None,
init_token: Optional[str] = None,
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
@@ -27,8 +28,8 @@ class Dataset(data.Dataset):
Args:
train (bool): If True, loads the training set, otherwise the validation set is loaded. Defaults to False.
subsample_fraction (float): The fraction of the dataset to use for training. Defaults to None.
- transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
- target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+ transform (Optional[List[Dict]]): List of Transform types and args for input data. Defaults to None.
+ target_transform (Optional[List[Dict]]): List of Transform types and args for output data. Defaults to None.
init_token (Optional[str]): String representing the start of sequence token. Defaults to None.
pad_token (Optional[str]): String representing the pad token. Defaults to None.
eos_token (Optional[str]): String representing the end of sequence token. Defaults to None.
@@ -53,14 +54,34 @@ class Dataset(data.Dataset):
self.num_classes = self.mapper.num_classes
# Set transforms.
- self.transform = transform if transform is not None else ToTensor()
- self.target_transform = (
- target_transform if target_transform is not None else torch.tensor
- )
+ self.transform = self._configure_transform(transform)
+ self.target_transform = self._configure_target_transform(target_transform)
self._data = None
self._targets = None
+    def _configure_transform(self, transform: List[Dict]) -> transforms.Compose:
+        transform_list = []
+        if transform is not None:
+            for t in transform:
+                t_type = t["type"]
+                t_args = t["args"] or {}
+                transform_list.append(getattr(transforms, t_type)(**t_args))
+        else:
+            transform_list.append(ToTensor())
+        return transforms.Compose(transform_list)
+
+    def _configure_target_transform(
+        self, target_transform: List[Dict]
+    ) -> transforms.Compose:
+        target_transform_list = [torch.tensor]
+        if target_transform is not None:
+            for t in target_transform:
+                t_type = t["type"]
+                t_args = t["args"] or {}
+                target_transform_list.append(getattr(transforms, t_type)(**t_args))
+        return transforms.Compose(target_transform_list)
+
@property
def data(self) -> Tensor:
"""The input data."""
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index fdd2fe6..5ae142c 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -36,6 +36,8 @@ class IamLinesDataset(Dataset):
pad_token: Optional[str] = None,
eos_token: Optional[str] = None,
) -> None:
+ self.pad_token = "_" if pad_token is None else pad_token
+
super().__init__(
train=train,
subsample_fraction=subsample_fraction,
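Note on the iam_lines_dataset.py change: the pad token is now resolved (defaulting to "_") before the parent constructor runs, so the attribute is always set on the instance. A hedged usage sketch, assuming the IAM lines data files are already available locally and using only arguments visible in the signature above:

# Hypothetical usage of IamLinesDataset.
dataset = IamLinesDataset(train=True, subsample_fraction=0.01)
assert dataset.pad_token == "_"  # default applied when pad_token is None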
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index d2df8b5..bf5e772 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -12,7 +12,8 @@ import cv2
from loguru import logger
import numpy as np
from PIL import Image
-from torch.utils.data import DataLoader, Dataset
+import torch
+from torch import Tensor
from torchvision.datasets import EMNIST
from tqdm import tqdm
@@ -20,7 +21,7 @@ DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
+def save_emnist_essentials(emnsit_dataset: EMNIST = EMNIST) -> None:
"""Extract and saves EMNIST essentials."""
labels = emnsit_dataset.classes
labels.sort()
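For reference, save_emnist_essentials expects a torchvision EMNIST dataset whose .classes attribute provides the label set. A rough sketch of regenerating the essentials file, assuming a local "data" root (the root path is an assumption, not taken from this diff):

# Hypothetical invocation; torchvision downloads EMNIST into "data" if needed.
from torchvision.datasets import EMNIST
from text_recognizer.datasets.util import save_emnist_essentials

emnist = EMNIST(root="data", split="byclass", download=True)
save_emnist_essentials(emnist)  # writes emnist_essentials.json next to util.py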
@@ -56,21 +57,21 @@ class EmnistMapper:
self.eos_token = eos_token
self.essentials = self._load_emnist_essentials()
- # Load dataset infromation.
+ # Load dataset information.
self._mapping = dict(self.essentials["mapping"])
self._augment_emnist_mapping()
self._inverse_mapping = {v: k for k, v in self.mapping.items()}
self._num_classes = len(self.mapping)
self._input_shape = self.essentials["input_shape"]
- def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ def __call__(self, token: Union[str, int, np.uint8, Tensor]) -> Union[str, int]:
"""Maps the token to emnist character or character index.
If the token is an integer (index), the method will return the Emnist character corresponding to that index.
If the token is a str (Emnist character), the method will return the corresponding index for that character.
Args:
- token (Union[str, int, np.uint8]): Eihter a string or index (integer).
+ token (Union[str, int, np.uint8, Tensor]): Either a string or index (integer).
Returns:
Union[str, int]: The mapping result.
@@ -79,9 +80,11 @@ class EmnistMapper:
KeyError: If the index or string does not exist in the mapping.
"""
-        if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
-            token
-        ) in self.mapping:
+        if (
+            isinstance(token, np.uint8)
+            or isinstance(token, int)
+            or torch.is_tensor(token)
+        ) and int(token) in self.mapping:
return self.mapping[int(token)]
elif isinstance(token, str) and token in self._inverse_mapping:
return self._inverse_mapping[token]
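Finally, EmnistMapper.__call__ now accepts torch tensors in addition to strings, ints, and np.uint8 indices; tensor inputs are converted with int(...) before the lookup. A rough round-trip sketch, assuming EmnistMapper can be constructed with its defaults and that the character used below exists in the EMNIST mapping:

# Hypothetical usage of the mapper shown above.
import torch
from text_recognizer.datasets.util import EmnistMapper

mapper = EmnistMapper()
index = mapper("a")                        # str -> int index
assert mapper(index) == "a"                # int -> str character
assert mapper(torch.tensor(index)) == "a"  # tensors are mapped the same way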