diff options
Diffstat (limited to 'src/text_recognizer')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 7 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 31 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 120 |
3 files changed, 129 insertions, 29 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index a8c46c4..1b4cc59 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,21 +1,28 @@ """Dataset modules.""" from .emnist_dataset import ( + _augment_emnist_mapping, + _load_emnist_essentials, DATA_DIRNAME, EmnistDataLoaders, EmnistDataset, + ESSENTIALS_FILENAME, ) from .emnist_lines_dataset import ( construct_image_from_string, + EmnistLinesDataLoaders, EmnistLinesDataset, get_samples_by_character, ) from .util import Transpose __all__ = [ + "_augment_emnist_mapping", + "_load_emnist_essentials", "construct_image_from_string", "DATA_DIRNAME", "EmnistDataset", "EmnistDataLoaders", + "EmnistLinesDataLoaders", "EmnistLinesDataset", "get_samples_by_character", "Transpose", diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index 525df95..f3d65ee 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -260,21 +260,23 @@ class EmnistDataLoaders: """ self.splits = splits - self.sample_to_balance = sample_to_balance if subsample_fraction is not None: if not 0.0 < subsample_fraction < 1.0: raise ValueError("The subsample fraction must be in (0, 1).") - self.subsample_fraction = subsample_fraction - self.transform = transform - self.target_transform = target_transform + self.dataset_args = { + "sample_to_balance": sample_to_balance, + "subsample_fraction": subsample_fraction, + "transform": transform, + "target_transform": target_transform, + "seed": seed, + } self.batch_size = batch_size self.shuffle = shuffle self.num_workers = num_workers self.cuda = cuda - self.seed = seed - self._data_loaders = self._fetch_emnist_data_loaders() + self._data_loaders = self._load_data_loaders() def __repr__(self) -> str: """Returns information about the dataset.""" @@ -303,7 +305,7 @@ class EmnistDataLoaders: except KeyError: raise ValueError(f"Split {split} does not exist.") - def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: + def _load_data_loaders(self) -> Dict[str, DataLoader]: """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" data_loaders = {} @@ -311,18 +313,11 @@ class EmnistDataLoaders: if split in self.splits: if split == "train": - train = True + self.dataset_args["train"] = True else: - train = False - - emnist_dataset = EmnistDataset( - train=train, - sample_to_balance=self.sample_to_balance, - subsample_fraction=self.subsample_fraction, - transform=self.transform, - target_transform=self.target_transform, - seed=self.seed, - ) + self.dataset_args["train"] = False + + emnist_dataset = EmnistDataset(**self.dataset_args) emnist_dataset.load_emnist_dataset() diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 4d8b646..1c6e959 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -8,17 +8,20 @@ import h5py from loguru import logger import numpy as np import torch -from torch.utils.data import Dataset +from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor -from text_recognizer.datasets import DATA_DIRNAME, EmnistDataset +from text_recognizer.datasets import ( + _augment_emnist_mapping, + _load_emnist_essentials, + DATA_DIRNAME, + EmnistDataset, + ESSENTIALS_FILENAME, +) from text_recognizer.datasets.sentence_generator import SentenceGenerator from text_recognizer.datasets.util import Transpose DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines" -ESSENTIALS_FILENAME = ( - Path(__file__).resolve().parents[0] / "emnist_lines_essentials.json" -) class EmnistLinesDataset(Dataset): @@ -26,8 +29,8 @@ class EmnistLinesDataset(Dataset): def __init__( self, - emnist: EmnistDataset, train: bool = False, + emnist: Optional[EmnistDataset] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, max_length: int = 34, @@ -40,7 +43,7 @@ class EmnistLinesDataset(Dataset): Args: emnist (EmnistDataset): A EmnistDataset object. - train (bool): Flag for the filename. Defaults to False. + train (bool): Flag for the filename. Defaults to False. Defaults to None. transform (Optional[Callable]): The transform of the data. Defaults to None. target_transform (Optional[Callable]): The transform of the target. Defaults to None. max_length (int): The maximum number of characters. Defaults to 34. @@ -61,15 +64,19 @@ class EmnistLinesDataset(Dataset): if self.target_transform is None: self.target_transform = torch.tensor - self.mapping = self.emnist.mapping - self.num_classes = self.emnist.num_classes + # Load emnist dataset infromation. + essentials = _load_emnist_essentials() + self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) + self.num_classes = len(self.mapping) + self.input_shape = essentials["input_shape"] + self.max_length = max_length self.min_overlap = min_overlap self.max_overlap = max_overlap self.num_samples = num_samples self.input_shape = ( - self.emnist.input_shape[0], - self.emnist.input_shape[1] * self.max_length, + self.input_shape[0], + self.input_shape[1] * self.max_length, ) self.output_shape = (self.max_length, self.num_classes) self.seed = seed @@ -325,3 +332,94 @@ def create_datasets( num_samples=num, ) emnist_lines._load_or_generate_data() + + +class EmnistLinesDataLoaders: + """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset.""" + + def __init__( + self, + splits: List[str], + max_length: int = 34, + min_overlap: float = 0, + max_overlap: float = 0.33, + num_samples: int = 10000, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, + seed: int = 4711, + ) -> None: + """Sets the data loader arguments.""" + self.splits = splits + self.dataset_args = { + "max_length": max_length, + "min_overlap": min_overlap, + "max_overlap": max_overlap, + "num_samples": num_samples, + "transform": transform, + "target_transform": target_transform, + "seed": seed, + } + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.cuda = cuda + self._data_loaders = self._load_data_loaders() + + def __repr__(self) -> str: + """Returns information about the dataset.""" + return self._data_loaders[self.splits[0]].dataset.__repr__() + + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EmnistLines" + + def __call__(self, split: str) -> DataLoader: + """Returns the `split` DataLoader. + + Args: + split (str): The dataset split, i.e. train or val. + + Returns: + DataLoader: A PyTorch DataLoader. + + Raises: + ValueError: If the split does not exist. + + """ + try: + return self._data_loaders[split] + except KeyError: + raise ValueError(f"Split {split} does not exist.") + + def _load_data_loaders(self) -> Dict[str, DataLoader]: + """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders.""" + data_loaders = {} + + for split in ["train", "val"]: + if split in self.splits: + + if split == "train": + self.dataset_args["train"] = True + else: + self.dataset_args["train"] = False + + emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args) + + emnist_lines_dataset._load_or_generate_data() + + data_loader = DataLoader( + dataset=emnist_lines_dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.cuda, + ) + + data_loaders[split] = data_loader + + return data_loaders |