summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/__init__.py13
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py275
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py129
-rw-r--r--src/text_recognizer/datasets/util.py60
4 files changed, 201 insertions, 276 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 1b4cc59..05f74f6 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,29 +1,24 @@
"""Dataset modules."""
from .emnist_dataset import (
- _augment_emnist_mapping,
- _load_emnist_essentials,
DATA_DIRNAME,
- EmnistDataLoaders,
EmnistDataset,
+ EmnistMapper,
ESSENTIALS_FILENAME,
)
from .emnist_lines_dataset import (
construct_image_from_string,
- EmnistLinesDataLoaders,
EmnistLinesDataset,
get_samples_by_character,
)
-from .util import Transpose
+from .util import fetch_data_loaders, Transpose
__all__ = [
- "_augment_emnist_mapping",
- "_load_emnist_essentials",
"construct_image_from_string",
"DATA_DIRNAME",
"EmnistDataset",
- "EmnistDataLoaders",
- "EmnistLinesDataLoaders",
+ "EmnistMapper",
"EmnistLinesDataset",
+ "fetch_data_loaders",
"get_samples_by_character",
"Transpose",
]
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index f3d65ee..96f84e5 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -39,45 +39,101 @@ def download_emnist() -> None:
save_emnist_essentials(dataset)
-def _load_emnist_essentials() -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
-
-def _augment_emnist_mapping(mapping: Dict) -> Dict:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol
- extra_symbols.append("_")
-
- max_key = max(mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- return {**mapping, **extra_mapping}
+class EmnistMapper:
+ """Mapper between network output to Emnist character."""
+
+ def __init__(self) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset infromation.
+ self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8]): Eihter a string or index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
+ token
+ ) in self.mapping:
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol
+ extra_symbols.append("_")
+
+ max_key = max(mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ return {**mapping, **extra_mapping}
class EmnistDataset(Dataset):
@@ -110,10 +166,12 @@ class EmnistDataset(Dataset):
self.train = train
self.sample_to_balance = sample_to_balance
+
if subsample_fraction is not None:
if not 0.0 < subsample_fraction < 1.0:
raise ValueError("The subsample fraction must be in (0, 1).")
self.subsample_fraction = subsample_fraction
+
self.transform = transform
if self.transform is None:
self.transform = Compose([Transpose(), ToTensor()])
@@ -121,17 +179,22 @@ class EmnistDataset(Dataset):
self.target_transform = target_transform
self.seed = seed
- # Load dataset infromation.
- essentials = _load_emnist_essentials()
- self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
- self.inverse_mapping = {v: k for k, v in self.mapping.items()}
- self.num_classes = len(self.mapping)
- self.input_shape = essentials["input_shape"]
+ self._mapper = EmnistMapper()
+ self.input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
# Placeholders
self.data = None
self.targets = None
+ # Load dataset.
+ self.load_emnist_dataset()
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -162,13 +225,18 @@ class EmnistDataset(Dataset):
return data, targets
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EmnistDataset"
+
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
"EMNIST Dataset\n"
f"Num classes: {self.num_classes}\n"
- f"Mapping: {self.mapping}\n"
f"Input shape: {self.input_shape}\n"
+ f"Mapping: {self.mapper.mapping}\n"
)
def _sample_to_balance(self) -> None:
@@ -217,118 +285,3 @@ class EmnistDataset(Dataset):
if self.subsample_fraction is not None:
self._subsample()
-
-
-class EmnistDataLoaders:
- """Class for Emnist DataLoaders."""
-
- def __init__(
- self,
- splits: List[str],
- sample_to_balance: bool = False,
- subsample_fraction: float = None,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
- seed: int = 4711,
- ) -> None:
- """Fetches DataLoaders for given split(s).
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected.
- Defaults to False.
- subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire
- dataset will be loaded.
- transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
- transformed version. E.g, transforms.RandomCrop. Defaults to None.
- target_transform (Optional[Callable]): A function/transform that takes in the target and
- transforms it. Defaults to None.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
- seed (int): Seed for sampling.
-
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
- """
- self.splits = splits
-
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
-
- self.dataset_args = {
- "sample_to_balance": sample_to_balance,
- "subsample_fraction": subsample_fraction,
- "transform": transform,
- "target_transform": target_transform,
- "seed": seed,
- }
- self.batch_size = batch_size
- self.shuffle = shuffle
- self.num_workers = num_workers
- self.cuda = cuda
- self._data_loaders = self._load_data_loaders()
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return self._data_loaders[self.splits[0]].dataset.__repr__()
-
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "Emnist"
-
- def __call__(self, split: str) -> DataLoader:
- """Returns the `split` DataLoader.
-
- Args:
- split (str): The dataset split, i.e. train or val.
-
- Returns:
- DataLoader: A PyTorch DataLoader.
-
- Raises:
- ValueError: If the split does not exist.
-
- """
- try:
- return self._data_loaders[split]
- except KeyError:
- raise ValueError(f"Split {split} does not exist.")
-
- def _load_data_loaders(self) -> Dict[str, DataLoader]:
- """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in self.splits:
-
- if split == "train":
- self.dataset_args["train"] = True
- else:
- self.dataset_args["train"] = False
-
- emnist_dataset = EmnistDataset(**self.dataset_args)
-
- emnist_dataset.load_emnist_dataset()
-
- data_loader = DataLoader(
- dataset=emnist_dataset,
- batch_size=self.batch_size,
- shuffle=self.shuffle,
- num_workers=self.num_workers,
- pin_memory=self.cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 1c6e959..d64a991 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
from text_recognizer.datasets import (
- _augment_emnist_mapping,
- _load_emnist_essentials,
DATA_DIRNAME,
EmnistDataset,
+ EmnistMapper,
ESSENTIALS_FILENAME,
)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
@@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset):
def __init__(
self,
train: bool = False,
- emnist: Optional[EmnistDataset] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
max_length: int = 34,
@@ -39,10 +37,9 @@ class EmnistLinesDataset(Dataset):
num_samples: int = 10000,
seed: int = 4711,
) -> None:
- """Short summary.
+ """Set attributes and loads the dataset.
Args:
- emnist (EmnistDataset): A EmnistDataset object.
train (bool): Flag for the filename. Defaults to False. Defaults to None.
transform (Optional[Callable]): The transform of the data. Defaults to None.
target_transform (Optional[Callable]): The transform of the target. Defaults to None.
@@ -54,7 +51,6 @@ class EmnistLinesDataset(Dataset):
"""
self.train = train
- self.emnist = emnist
self.transform = transform
if self.transform is None:
@@ -64,11 +60,10 @@ class EmnistLinesDataset(Dataset):
if self.target_transform is None:
self.target_transform = torch.tensor
- # Load emnist dataset infromation.
- essentials = _load_emnist_essentials()
- self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
- self.num_classes = len(self.mapping)
- self.input_shape = essentials["input_shape"]
+ # Extract dataset information.
+ self._mapper = EmnistMapper()
+ self.input_shape = self._mapper.input_shape
+ self.num_classes = self._mapper.num_classes
self.max_length = max_length
self.min_overlap = min_overlap
@@ -81,10 +76,13 @@ class EmnistLinesDataset(Dataset):
self.output_shape = (self.max_length, self.num_classes)
self.seed = seed
- # Placeholders for the generated dataset.
+ # Placeholders for the dataset.
self.data = None
self.target = None
+ # Load dataset.
+ self._load_or_generate_data()
+
def __len__(self) -> int:
"""Returns the length of the dataset."""
return len(self.data)
@@ -104,7 +102,6 @@ class EmnistLinesDataset(Dataset):
if torch.is_tensor(index):
index = index.tolist()
- # data = np.array([self.data[index]])
data = self.data[index]
targets = self.targets[index]
@@ -116,6 +113,11 @@ class EmnistLinesDataset(Dataset):
return data, targets
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EmnistLinesDataset"
+
def __repr__(self) -> str:
"""Returns information about the dataset."""
return (
@@ -130,6 +132,11 @@ class EmnistLinesDataset(Dataset):
)
@property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -161,9 +168,10 @@ class EmnistLinesDataset(Dataset):
sentence_generator = SentenceGenerator(self.max_length)
# Load emnist dataset.
- self.emnist.load_emnist_dataset()
+ emnist = EmnistDataset(train=self.train, sample_to_balance=True)
+
samples_by_character = get_samples_by_character(
- self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping,
+ emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping,
)
DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
@@ -332,94 +340,3 @@ def create_datasets(
num_samples=num,
)
emnist_lines._load_or_generate_data()
-
-
-class EmnistLinesDataLoaders:
- """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset."""
-
- def __init__(
- self,
- splits: List[str],
- max_length: int = 34,
- min_overlap: float = 0,
- max_overlap: float = 0.33,
- num_samples: int = 10000,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
- seed: int = 4711,
- ) -> None:
- """Sets the data loader arguments."""
- self.splits = splits
- self.dataset_args = {
- "max_length": max_length,
- "min_overlap": min_overlap,
- "max_overlap": max_overlap,
- "num_samples": num_samples,
- "transform": transform,
- "target_transform": target_transform,
- "seed": seed,
- }
- self.batch_size = batch_size
- self.shuffle = shuffle
- self.num_workers = num_workers
- self.cuda = cuda
- self._data_loaders = self._load_data_loaders()
-
- def __repr__(self) -> str:
- """Returns information about the dataset."""
- return self._data_loaders[self.splits[0]].dataset.__repr__()
-
- @property
- def __name__(self) -> str:
- """Returns the name of the dataset."""
- return "EmnistLines"
-
- def __call__(self, split: str) -> DataLoader:
- """Returns the `split` DataLoader.
-
- Args:
- split (str): The dataset split, i.e. train or val.
-
- Returns:
- DataLoader: A PyTorch DataLoader.
-
- Raises:
- ValueError: If the split does not exist.
-
- """
- try:
- return self._data_loaders[split]
- except KeyError:
- raise ValueError(f"Split {split} does not exist.")
-
- def _load_data_loaders(self) -> Dict[str, DataLoader]:
- """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders."""
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in self.splits:
-
- if split == "train":
- self.dataset_args["train"] = True
- else:
- self.dataset_args["train"] = False
-
- emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args)
-
- emnist_lines_dataset._load_or_generate_data()
-
- data_loader = DataLoader(
- dataset=emnist_lines_dataset,
- batch_size=self.batch_size,
- shuffle=self.shuffle,
- num_workers=self.num_workers,
- pin_memory=self.cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index 6668eef..321bc67 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,6 +1,10 @@
"""Util functions for datasets."""
+import importlib
+from typing import Callable, Dict, List, Type
+
import numpy as np
from PIL import Image
+from torch.utils.data import DataLoader, Dataset
class Transpose:
@@ -9,3 +13,59 @@ class Transpose:
def __call__(self, image: Image) -> np.ndarray:
"""Swaps axis."""
return np.array(image).swapaxes(0, 1)
+
+
+def fetch_data_loaders(
+ splits: List[str],
+ dataset: str,
+ dataset_args: Dict,
+ batch_size: int = 128,
+ shuffle: bool = False,
+ num_workers: int = 0,
+ cuda: bool = True,
+) -> Dict[str, DataLoader]:
+ """Fetches DataLoaders for given split(s) as a dictionary.
+
+ Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then
+ calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value.
+
+ Args:
+ splits (List[str]): One or both of the dataset splits "train" and "val".
+ dataset (str): The name of the dataset.
+ dataset_args (Dict): The dataset arguments.
+ batch_size (int): How many samples per batch to load. Defaults to 128.
+ shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
+ num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
+ loaded in the main process. Defaults to 0.
+ cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
+ them. Defaults to True.
+
+ Returns:
+ Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value.
+
+ """
+
+ def check_dataset_args(args: Dict, split: str) -> Dict:
+ args["train"] = True if split == "train" else False
+ return args
+
+ # Import dataset module.
+ datasets_module = importlib.import_module("text_recognizer.datasets")
+ dataset_ = getattr(datasets_module, dataset)
+
+ data_loaders = {}
+
+ for split in ["train", "val"]:
+ if split in splits:
+
+ data_loader = DataLoader(
+ dataset=dataset_(**check_dataset_args(dataset_args, split)),
+ batch_size=batch_size,
+ shuffle=shuffle,
+ num_workers=num_workers,
+ pin_memory=cuda,
+ )
+
+ data_loaders[split] = data_loader
+
+ return data_loaders