diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
commit | 53677be4ec14854ea4881b0d78730e0414c8dedd (patch) | |
tree | 56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/text_recognizer/datasets | |
parent | 125d5da5fb845d03bda91426e172bca7f537584a (diff) |
Working bash scripts etc.
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 13 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 275 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_lines_dataset.py | 129 | ||||
-rw-r--r-- | src/text_recognizer/datasets/util.py | 60 |
4 files changed, 201 insertions, 276 deletions
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 1b4cc59..05f74f6 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,29 +1,24 @@ """Dataset modules.""" from .emnist_dataset import ( - _augment_emnist_mapping, - _load_emnist_essentials, DATA_DIRNAME, - EmnistDataLoaders, EmnistDataset, + EmnistMapper, ESSENTIALS_FILENAME, ) from .emnist_lines_dataset import ( construct_image_from_string, - EmnistLinesDataLoaders, EmnistLinesDataset, get_samples_by_character, ) -from .util import Transpose +from .util import fetch_data_loaders, Transpose __all__ = [ - "_augment_emnist_mapping", - "_load_emnist_essentials", "construct_image_from_string", "DATA_DIRNAME", "EmnistDataset", - "EmnistDataLoaders", - "EmnistLinesDataLoaders", + "EmnistMapper", "EmnistLinesDataset", + "fetch_data_loaders", "get_samples_by_character", "Transpose", ] diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index f3d65ee..96f84e5 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -39,45 +39,101 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def _load_emnist_essentials() -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - -def _augment_emnist_mapping(mapping: Dict) -> Dict: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol - extra_symbols.append("_") - - max_key = max(mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - return {**mapping, **extra_mapping} +class EmnistMapper: + """Mapper between network output to Emnist character.""" + + def __init__(self) -> None: + """Loads the emnist essentials file with the mapping and input shape.""" + self.essentials = self._load_emnist_essentials() + # Load dataset infromation. + self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"])) + self._inverse_mapping = {v: k for k, v in self.mapping.items()} + self._num_classes = len(self.mapping) + self._input_shape = self.essentials["input_shape"] + + def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + """Maps the token to emnist character or character index. + + If the token is an integer (index), the method will return the Emnist character corresponding to that index. + If the token is a str (Emnist character), the method will return the corresponding index for that character. + + Args: + token (Union[str, int, np.uint8]): Eihter a string or index (integer). + + Returns: + Union[str, int]: The mapping result. + + Raises: + KeyError: If the index or string does not exist in the mapping. + + """ + if (isinstance(token, np.uint8) or isinstance(token, int)) and int( + token + ) in self.mapping: + return self.mapping[int(token)] + elif isinstance(token, str) and token in self._inverse_mapping: + return self._inverse_mapping[token] + else: + raise KeyError(f"Token {token} does not exist in the mappings.") + + @property + def mapping(self) -> Dict: + """Returns the mapping between index and character.""" + return self._mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the mapping between character and index.""" + return self._inverse_mapping + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset.""" + return self._num_classes + + @property + def input_shape(self) -> List[int]: + """Returns the input shape of the Emnist characters.""" + return self._input_shape + + def _load_emnist_essentials(self) -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return essentials + + def _augment_emnist_mapping(self, mapping: Dict) -> Dict: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol + extra_symbols.append("_") + + max_key = max(mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + return {**mapping, **extra_mapping} class EmnistDataset(Dataset): @@ -110,10 +166,12 @@ class EmnistDataset(Dataset): self.train = train self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: if not 0.0 < subsample_fraction < 1.0: raise ValueError("The subsample fraction must be in (0, 1).") self.subsample_fraction = subsample_fraction + self.transform = transform if self.transform is None: self.transform = Compose([Transpose(), ToTensor()]) @@ -121,17 +179,22 @@ class EmnistDataset(Dataset): self.target_transform = target_transform self.seed = seed - # Load dataset infromation. - essentials = _load_emnist_essentials() - self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - self.inverse_mapping = {v: k for k, v in self.mapping.items()} - self.num_classes = len(self.mapping) - self.input_shape = essentials["input_shape"] + self._mapper = EmnistMapper() + self.input_shape = self._mapper.input_shape + self.num_classes = self._mapper.num_classes # Placeholders self.data = None self.targets = None + # Load dataset. + self.load_emnist_dataset() + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -162,13 +225,18 @@ class EmnistDataset(Dataset): return data, targets + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EmnistDataset" + def __repr__(self) -> str: """Returns information about the dataset.""" return ( "EMNIST Dataset\n" f"Num classes: {self.num_classes}\n" - f"Mapping: {self.mapping}\n" f"Input shape: {self.input_shape}\n" + f"Mapping: {self.mapper.mapping}\n" ) def _sample_to_balance(self) -> None: @@ -217,118 +285,3 @@ class EmnistDataset(Dataset): if self.subsample_fraction is not None: self._subsample() - - -class EmnistDataLoaders: - """Class for Emnist DataLoaders.""" - - def __init__( - self, - splits: List[str], - sample_to_balance: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, - seed: int = 4711, - ) -> None: - """Fetches DataLoaders for given split(s). - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. - Defaults to False. - subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire - dataset will be loaded. - transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a - transformed version. E.g, transforms.RandomCrop. Defaults to None. - target_transform (Optional[Callable]): A function/transform that takes in the target and - transforms it. Defaults to None. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - seed (int): Seed for sampling. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.splits = splits - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - - self.dataset_args = { - "sample_to_balance": sample_to_balance, - "subsample_fraction": subsample_fraction, - "transform": transform, - "target_transform": target_transform, - "seed": seed, - } - self.batch_size = batch_size - self.shuffle = shuffle - self.num_workers = num_workers - self.cuda = cuda - self._data_loaders = self._load_data_loaders() - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return self._data_loaders[self.splits[0]].dataset.__repr__() - - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "Emnist" - - def __call__(self, split: str) -> DataLoader: - """Returns the `split` DataLoader. - - Args: - split (str): The dataset split, i.e. train or val. - - Returns: - DataLoader: A PyTorch DataLoader. - - Raises: - ValueError: If the split does not exist. - - """ - try: - return self._data_loaders[split] - except KeyError: - raise ValueError(f"Split {split} does not exist.") - - def _load_data_loaders(self) -> Dict[str, DataLoader]: - """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" - data_loaders = {} - - for split in ["train", "val"]: - if split in self.splits: - - if split == "train": - self.dataset_args["train"] = True - else: - self.dataset_args["train"] = False - - emnist_dataset = EmnistDataset(**self.dataset_args) - - emnist_dataset.load_emnist_dataset() - - data_loader = DataLoader( - dataset=emnist_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - pin_memory=self.cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py index 1c6e959..d64a991 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/src/text_recognizer/datasets/emnist_lines_dataset.py @@ -12,10 +12,9 @@ from torch.utils.data import DataLoader, Dataset from torchvision.transforms import Compose, Normalize, ToTensor from text_recognizer.datasets import ( - _augment_emnist_mapping, - _load_emnist_essentials, DATA_DIRNAME, EmnistDataset, + EmnistMapper, ESSENTIALS_FILENAME, ) from text_recognizer.datasets.sentence_generator import SentenceGenerator @@ -30,7 +29,6 @@ class EmnistLinesDataset(Dataset): def __init__( self, train: bool = False, - emnist: Optional[EmnistDataset] = None, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, max_length: int = 34, @@ -39,10 +37,9 @@ class EmnistLinesDataset(Dataset): num_samples: int = 10000, seed: int = 4711, ) -> None: - """Short summary. + """Set attributes and loads the dataset. Args: - emnist (EmnistDataset): A EmnistDataset object. train (bool): Flag for the filename. Defaults to False. Defaults to None. transform (Optional[Callable]): The transform of the data. Defaults to None. target_transform (Optional[Callable]): The transform of the target. Defaults to None. @@ -54,7 +51,6 @@ class EmnistLinesDataset(Dataset): """ self.train = train - self.emnist = emnist self.transform = transform if self.transform is None: @@ -64,11 +60,10 @@ class EmnistLinesDataset(Dataset): if self.target_transform is None: self.target_transform = torch.tensor - # Load emnist dataset infromation. - essentials = _load_emnist_essentials() - self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - self.num_classes = len(self.mapping) - self.input_shape = essentials["input_shape"] + # Extract dataset information. + self._mapper = EmnistMapper() + self.input_shape = self._mapper.input_shape + self.num_classes = self._mapper.num_classes self.max_length = max_length self.min_overlap = min_overlap @@ -81,10 +76,13 @@ class EmnistLinesDataset(Dataset): self.output_shape = (self.max_length, self.num_classes) self.seed = seed - # Placeholders for the generated dataset. + # Placeholders for the dataset. self.data = None self.target = None + # Load dataset. + self._load_or_generate_data() + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -104,7 +102,6 @@ class EmnistLinesDataset(Dataset): if torch.is_tensor(index): index = index.tolist() - # data = np.array([self.data[index]]) data = self.data[index] targets = self.targets[index] @@ -116,6 +113,11 @@ class EmnistLinesDataset(Dataset): return data, targets + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EmnistLinesDataset" + def __repr__(self) -> str: """Returns information about the dataset.""" return ( @@ -130,6 +132,11 @@ class EmnistLinesDataset(Dataset): ) @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + + @property def data_filename(self) -> Path: """Path to the h5 file.""" filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt" @@ -161,9 +168,10 @@ class EmnistLinesDataset(Dataset): sentence_generator = SentenceGenerator(self.max_length) # Load emnist dataset. - self.emnist.load_emnist_dataset() + emnist = EmnistDataset(train=self.train, sample_to_balance=True) + samples_by_character = get_samples_by_character( - self.emnist.data.numpy(), self.emnist.targets.numpy(), self.emnist.mapping, + emnist.data.numpy(), emnist.targets.numpy(), self.mapper.mapping, ) DATA_DIRNAME.mkdir(parents=True, exist_ok=True) @@ -332,94 +340,3 @@ def create_datasets( num_samples=num, ) emnist_lines._load_or_generate_data() - - -class EmnistLinesDataLoaders: - """Wrapper for a PyTorch Data loaders for the EMNIST lines dataset.""" - - def __init__( - self, - splits: List[str], - max_length: int = 34, - min_overlap: float = 0, - max_overlap: float = 0.33, - num_samples: int = 10000, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, - seed: int = 4711, - ) -> None: - """Sets the data loader arguments.""" - self.splits = splits - self.dataset_args = { - "max_length": max_length, - "min_overlap": min_overlap, - "max_overlap": max_overlap, - "num_samples": num_samples, - "transform": transform, - "target_transform": target_transform, - "seed": seed, - } - self.batch_size = batch_size - self.shuffle = shuffle - self.num_workers = num_workers - self.cuda = cuda - self._data_loaders = self._load_data_loaders() - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return self._data_loaders[self.splits[0]].dataset.__repr__() - - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "EmnistLines" - - def __call__(self, split: str) -> DataLoader: - """Returns the `split` DataLoader. - - Args: - split (str): The dataset split, i.e. train or val. - - Returns: - DataLoader: A PyTorch DataLoader. - - Raises: - ValueError: If the split does not exist. - - """ - try: - return self._data_loaders[split] - except KeyError: - raise ValueError(f"Split {split} does not exist.") - - def _load_data_loaders(self) -> Dict[str, DataLoader]: - """Fetches the EMNIST Lines dataset and return a Dict of PyTorch DataLoaders.""" - data_loaders = {} - - for split in ["train", "val"]: - if split in self.splits: - - if split == "train": - self.dataset_args["train"] = True - else: - self.dataset_args["train"] = False - - emnist_lines_dataset = EmnistLinesDataset(**self.dataset_args) - - emnist_lines_dataset._load_or_generate_data() - - data_loader = DataLoader( - dataset=emnist_lines_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - pin_memory=self.cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py index 6668eef..321bc67 100644 --- a/src/text_recognizer/datasets/util.py +++ b/src/text_recognizer/datasets/util.py @@ -1,6 +1,10 @@ """Util functions for datasets.""" +import importlib +from typing import Callable, Dict, List, Type + import numpy as np from PIL import Image +from torch.utils.data import DataLoader, Dataset class Transpose: @@ -9,3 +13,59 @@ class Transpose: def __call__(self, image: Image) -> np.ndarray: """Swaps axis.""" return np.array(image).swapaxes(0, 1) + + +def fetch_data_loaders( + splits: List[str], + dataset: str, + dataset_args: Dict, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, +) -> Dict[str, DataLoader]: + """Fetches DataLoaders for given split(s) as a dictionary. + + Loads the dataset class given, and loads it with the dataset arguments, for the number of splits specified. Then + calls the DataLoader. Added to a dictionary with the split as key and DataLoader as value. + + Args: + splits (List[str]): One or both of the dataset splits "train" and "val". + dataset (str): The name of the dataset. + dataset_args (Dict): The dataset arguments. + batch_size (int): How many samples per batch to load. Defaults to 128. + shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be + loaded in the main process. Defaults to 0. + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning + them. Defaults to True. + + Returns: + Dict[str, DataLoader]: Dictionary with split as key and PyTorch DataLoader as value. + + """ + + def check_dataset_args(args: Dict, split: str) -> Dict: + args["train"] = True if split == "train" else False + return args + + # Import dataset module. + datasets_module = importlib.import_module("text_recognizer.datasets") + dataset_ = getattr(datasets_module, dataset) + + data_loaders = {} + + for split in ["train", "val"]: + if split in splits: + + data_loader = DataLoader( + dataset=dataset_(**check_dataset_args(dataset_args, split)), + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=cuda, + ) + + data_loaders[split] = data_loader + + return data_loaders |