summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/emnist_lines_dataset.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py120
1 files changed, 109 insertions, 11 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 4d8b646..1c6e959 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -8,17 +8,20 @@ import h5py
from loguru import logger
import numpy as np
import torch
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import Compose, Normalize, ToTensor
-from text_recognizer.datasets import DATA_DIRNAME, EmnistDataset
+from text_recognizer.datasets import (
+ _augment_emnist_mapping,
+ _load_emnist_essentials,
+ DATA_DIRNAME,
+ EmnistDataset,
+ ESSENTIALS_FILENAME,
+)
from text_recognizer.datasets.sentence_generator import SentenceGenerator
from text_recognizer.datasets.util import Transpose
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
-ESSENTIALS_FILENAME = (
- Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json"
-)
class EmnistLinesDataset(Dataset):
@@ -26,8 +29,8 @@ class EmnistLinesDataset(Dataset):
def __init__(
self,
- emnist: EmnistDataset,
train: bool = False,
+ emnist: Optional[EmnistDataset] = None,
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
max_length: int = 34,
@@ -40,7 +43,7 @@ class EmnistLinesDataset(Dataset):
Args:
emnist (EmnistDataset): A EmnistDataset object.
- train (bool): Flag for the filename. Defaults to False.
+ train (bool): Flag for the filename. Defaults to False.
transform (Optional[Callable]): The transform of the data. Defaults to None.
target_transform (Optional[Callable]): The transform of the target. Defaults to None.
max_length (int): The maximum number of characters. Defaults to 34.
@@ -61,15 +64,19 @@ class EmnistLinesDataset(Dataset):
if self.target_transform is None:
self.target_transform = torch.tensor
- self.mapping = self.emnist.mapping
- self.num_classes = self.emnist.num_classes
+        # Load EMNIST dataset information.
+ essentials = _load_emnist_essentials()
+ self.mapping = _augment_emnist_mapping(dict(essentials["mapping"]))
+ self.num_classes = len(self.mapping)
+ self.input_shape = essentials["input_shape"]
+
self.max_length = max_length
self.min_overlap = min_overlap
self.max_overlap = max_overlap
self.num_samples = num_samples
self.input_shape = (
- self.emnist.input_shape[0],
- self.emnist.input_shape[1] * self.max_length,
+ self.input_shape[0],
+ self.input_shape[1] * self.max_length,
)
self.output_shape = (self.max_length, self.num_classes)
self.seed = seed
@@ -325,3 +332,94 @@ def create_datasets(
num_samples=num,
)
emnist_lines._load_or_generate_data()
+
+
class EmnistLinesDataLoaders:
    """Wrapper for PyTorch DataLoaders for the EMNIST lines dataset."""

    def __init__(
        self,
        splits: List[str],
        max_length: int = 34,
        min_overlap: float = 0,
        max_overlap: float = 0.33,
        num_samples: int = 10000,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        batch_size: int = 128,
        shuffle: bool = False,
        num_workers: int = 0,
        cuda: bool = True,
        seed: int = 4711,
    ) -> None:
        """Sets the data loader arguments.

        Args:
            splits (List[str]): Dataset splits to load; only "train" and "val"
                are recognized.
            max_length (int): The maximum number of characters. Defaults to 34.
            min_overlap (float): Minimum character overlap. Defaults to 0.
            max_overlap (float): Maximum character overlap. Defaults to 0.33.
            num_samples (int): Number of samples per split. Defaults to 10000.
            transform (Optional[Callable]): The transform of the data. Defaults to None.
            target_transform (Optional[Callable]): The transform of the target. Defaults to None.
            batch_size (int): Samples per batch. Defaults to 128.
            shuffle (bool): Whether to reshuffle the data each epoch. Defaults to False.
            num_workers (int): Number of data-loading subprocesses. Defaults to 0.
            cuda (bool): If True, pin memory for faster host-to-GPU transfer.
                Defaults to True.
            seed (int): Random seed passed to the dataset. Defaults to 4711.

        """
        self.splits = splits
        # Shared constructor arguments for every split's EmnistLinesDataset.
        # The split-specific `train` flag is added per split in
        # _load_data_loaders without mutating this dict.
        self.dataset_args = {
            "max_length": max_length,
            "min_overlap": min_overlap,
            "max_overlap": max_overlap,
            "num_samples": num_samples,
            "transform": transform,
            "target_transform": target_transform,
            "seed": seed,
        }
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.num_workers = num_workers
        self.cuda = cuda
        self._data_loaders = self._load_data_loaders()

    def __repr__(self) -> str:
        """Returns information about the dataset."""
        # Delegate to the first requested split's underlying dataset repr.
        return self._data_loaders[self.splits[0]].dataset.__repr__()

    @property
    def __name__(self) -> str:
        """Returns the name of the dataset."""
        return "EmnistLines"

    def __call__(self, split: str) -> DataLoader:
        """Returns the `split` DataLoader.

        Args:
            split (str): The dataset split, i.e. train or val.

        Returns:
            DataLoader: A PyTorch DataLoader.

        Raises:
            ValueError: If the split does not exist.

        """
        try:
            return self._data_loaders[split]
        except KeyError as e:
            # Chain the original KeyError for easier debugging.
            raise ValueError(f"Split {split} does not exist.") from e

    def _load_data_loaders(self) -> Dict[str, DataLoader]:
        """Fetches the EMNIST Lines dataset and returns a Dict of PyTorch DataLoaders."""
        data_loaders = {}

        # Only "train" and "val" are supported; other entries in self.splits
        # are ignored here (and later raise ValueError in __call__).
        for split in ("train", "val"):
            if split not in self.splits:
                continue

            # Build per-split args without mutating the shared dict.
            emnist_lines_dataset = EmnistLinesDataset(
                train=(split == "train"), **self.dataset_args
            )
            emnist_lines_dataset._load_or_generate_data()

            data_loaders[split] = DataLoader(
                dataset=emnist_lines_dataset,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                num_workers=self.num_workers,
                pin_memory=self.cuda,
            )

        return data_loaders