diff options
Diffstat (limited to 'src/text_recognizer/datasets/emnist_lines_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 120 |
1 files changed, 109 insertions, 11 deletions
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 4d8b646..1c6e959 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -8,17 +8,20 @@ import h5py from loguru import logger import numpy as np import torch -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor -from text_recognizer.datasets import DATA_DIRNAME, EmnistDataset +from text_recognizer.datasets import ( + _augment_emnist_mapping, + _load_emnist_essentials, + DATA_DIRNAME, + EmnistDataset, + ESSENTIALS_FILENAME, +) from text_recognizer.datasets.sentence_generator import SentenceGenerator from text_recognizer.datasets.util import Transpose DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" -ESSENTIALS_FILENAME = ( - Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json" -) class EmnistLinesDataset(Dataset): @@ -26,8 +29,8 @@ class EmnistLinesDataset(Dataset): def __init__( self, - emnist: EmnistDataset, train: bool = False, + emnist: Optional[EmnistDataset] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, max_length: int = 34, @@ -40,7 +43,7 @@ class EmnistLinesDataset(Dataset): Args: emnist (EmnistDataset): A EmnistDataset object. - train (bool): Flag for the filename. Defaults to False. + train (bool): Flag for the filename. Defaults to False. Defaults to None. transform (Optional[Callable]): The transform of the data. Defaults to None. target_transform (Optional[Callable]): The transform of the target. Defaults to None. max_length (int): The maximum number of characters. Defaults to 34. 
@@ -61,15 +64,19 @@ class EmnistLinesDataset(Dataset): if self.target_transform is None: self.target_transform = torch.tensor - self.mapping = self.emnist.mapping - self.num_classes = self.emnist.num_classes + # Load emnist dataset information. + essentials = _load_emnist_essentials() + self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) + self.num_classes = len(self.mapping) + self.input_shape = essentials["input_shape"] + self.max_length = max_length self.min_overlap = min_overlap self.max_overlap = max_overlap self.num_samples = num_samples self.input_shape = ( - self.emnist.input_shape[0], - self.emnist.input_shape[1] * self.max_length, + self.input_shape[0], + self.input_shape[1] * self.max_length, ) self.output_shape = (self.max_length, self.num_classes) self.seed = seed @@ -325,3 +332,94 @@ def create_datasets( num_samples=num, ) emnist_lines._load_or_generate_data() + + +class EmnistLinesDataLoaders: + """Wrapper for the PyTorch data loaders for the EMNIST lines dataset.""" + + def __init__( + self, + splits: List[str], + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_samples: int = 10000, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, + seed: int = 4711, + ) -> None: + """Sets the data loader arguments.""" + self.splits = splits + self.dataset_args = { + "max_length": max_length, + "min_overlap": min_overlap, + "max_overlap": max_overlap, + "num_samples": num_samples, + "transform": transform, + "target_transform": target_transform, + "seed": seed, + } + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.cuda = cuda + self._data_loaders = self._load_data_loaders() + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return self._data_loaders[self.splits[0]].dataset.__repr__() + + @property + def 
__name__(self) -> str: + """Returns the name of the dataset.""" + return "EmnistLines" + + def __call__(self, split: str) -> DataLoader: + """Returns the `split` DataLoader. + + Args: + split (str): The dataset split, i.e. train or val. + + Returns: + DataLoader: A PyTorch DataLoader. + + Raises: + ValueError: If the split does not exist. + + """ + try: + return self._data_loaders[split] + except KeyError: + raise ValueError(f"Split {split} does not exist.") + + def _load_data_loaders(self) -> Dict[str, DataLoader]: + """Fetches the EMNIST Lines dataset and returns a Dict of PyTorch DataLoaders.""" + data_loaders = {} + + for split in ["train", "val"]: + if split in self.splits: + + if split == "train": + self.dataset_args["train"] = True + else: + self.dataset_args["train"] = False + + emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args) + + emnist_lines_dataset._load_or_generate_data() + + data_loader = DataLoader( + dataset=emnist_lines_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.cuda, + ) + + data_loaders[split] = data_loader + + return data_loaders |