diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-09 23:24:02 +0200 |
commit | 53677be4ec14854ea4881b0d78730e0414c8dedd (patch) | |
tree | 56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/text_recognizer/datasets/emnist_dataset.py | |
parent | 125d5da5fb845d03bda91426e172bca7f537584a (diff) |
Working bash scripts etc.
Diffstat (limited to 'src/text_recognizer/datasets/emnist_dataset.py')
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 275 |
1 files changed, 114 insertions, 161 deletions
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index f3d65ee..96f84e5 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -39,45 +39,101 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def _load_emnist_essentials() -> Dict: - """Load the EMNIST mapping.""" - with open(str(ESSENTIALS_FILENAME)) as f: - essentials = json.load(f) - return essentials - - -def _augment_emnist_mapping(mapping: Dict) -> Dict: - """Augment the mapping with extra symbols.""" - # Extra symbols in IAM dataset - extra_symbols = [ - " ", - "!", - '"', - "#", - "&", - "'", - "(", - ")", - "*", - "+", - ",", - "-", - ".", - "/", - ":", - ";", - "?", - ] - - # padding symbol - extra_symbols.append("_") - - max_key = max(mapping.keys()) - extra_mapping = {} - for i, symbol in enumerate(extra_symbols): - extra_mapping[max_key + 1 + i] = symbol - - return {**mapping, **extra_mapping} +class EmnistMapper: + """Mapper between network output to Emnist character.""" + + def __init__(self) -> None: + """Loads the emnist essentials file with the mapping and input shape.""" + self.essentials = self._load_emnist_essentials() + # Load dataset infromation. + self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"])) + self._inverse_mapping = {v: k for k, v in self.mapping.items()} + self._num_classes = len(self.mapping) + self._input_shape = self.essentials["input_shape"] + + def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]: + """Maps the token to emnist character or character index. + + If the token is an integer (index), the method will return the Emnist character corresponding to that index. + If the token is a str (Emnist character), the method will return the corresponding index for that character. + + Args: + token (Union[str, int, np.uint8]): Eihter a string or index (integer). + + Returns: + Union[str, int]: The mapping result. + + Raises: + KeyError: If the index or string does not exist in the mapping. + + """ + if (isinstance(token, np.uint8) or isinstance(token, int)) and int( + token + ) in self.mapping: + return self.mapping[int(token)] + elif isinstance(token, str) and token in self._inverse_mapping: + return self._inverse_mapping[token] + else: + raise KeyError(f"Token {token} does not exist in the mappings.") + + @property + def mapping(self) -> Dict: + """Returns the mapping between index and character.""" + return self._mapping + + @property + def inverse_mapping(self) -> Dict: + """Returns the mapping between character and index.""" + return self._inverse_mapping + + @property + def num_classes(self) -> int: + """Returns the number of classes in the dataset.""" + return self._num_classes + + @property + def input_shape(self) -> List[int]: + """Returns the input shape of the Emnist characters.""" + return self._input_shape + + def _load_emnist_essentials(self) -> Dict: + """Load the EMNIST mapping.""" + with open(str(ESSENTIALS_FILENAME)) as f: + essentials = json.load(f) + return essentials + + def _augment_emnist_mapping(self, mapping: Dict) -> Dict: + """Augment the mapping with extra symbols.""" + # Extra symbols in IAM dataset + extra_symbols = [ + " ", + "!", + '"', + "#", + "&", + "'", + "(", + ")", + "*", + "+", + ",", + "-", + ".", + "/", + ":", + ";", + "?", + ] + + # padding symbol + extra_symbols.append("_") + + max_key = max(mapping.keys()) + extra_mapping = {} + for i, symbol in enumerate(extra_symbols): + extra_mapping[max_key + 1 + i] = symbol + + return {**mapping, **extra_mapping} class EmnistDataset(Dataset): @@ -110,10 +166,12 @@ class EmnistDataset(Dataset): self.train = train self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: if not 0.0 < subsample_fraction < 1.0: raise ValueError("The subsample fraction must be in (0, 1).") self.subsample_fraction = subsample_fraction + self.transform = transform if self.transform is None: self.transform = Compose([Transpose(), ToTensor()]) @@ -121,17 +179,22 @@ class EmnistDataset(Dataset): self.target_transform = target_transform self.seed = seed - # Load dataset infromation. - essentials = _load_emnist_essentials() - self.mapping = _augment_emnist_mapping(dict(essentials["mapping"])) - self.inverse_mapping = {v: k for k, v in self.mapping.items()} - self.num_classes = len(self.mapping) - self.input_shape = essentials["input_shape"] + self._mapper = EmnistMapper() + self.input_shape = self._mapper.input_shape + self.num_classes = self._mapper.num_classes # Placeholders self.data = None self.targets = None + # Load dataset. + self.load_emnist_dataset() + + @property + def mapper(self) -> EmnistMapper: + """Returns the EmnistMapper.""" + return self._mapper + def __len__(self) -> int: """Returns the length of the dataset.""" return len(self.data) @@ -162,13 +225,18 @@ class EmnistDataset(Dataset): return data, targets + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EmnistDataset" + def __repr__(self) -> str: """Returns information about the dataset.""" return ( "EMNIST Dataset\n" f"Num classes: {self.num_classes}\n" - f"Mapping: {self.mapping}\n" f"Input shape: {self.input_shape}\n" + f"Mapping: {self.mapper.mapping}\n" ) def _sample_to_balance(self) -> None: @@ -217,118 +285,3 @@ class EmnistDataset(Dataset): if self.subsample_fraction is not None: self._subsample() - - -class EmnistDataLoaders: - """Class for Emnist DataLoaders.""" - - def __init__( - self, - splits: List[str], - sample_to_balance: bool = False, - subsample_fraction: float = None, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, - seed: int = 4711, - ) -> None: - """Fetches DataLoaders for given split(s). - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. - Defaults to False. - subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire - dataset will be loaded. - transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a - transformed version. E.g, transforms.RandomCrop. Defaults to None. - target_transform (Optional[Callable]): A function/transform that takes in the target and - transforms it. Defaults to None. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - seed (int): Seed for sampling. - - Raises: - ValueError: If subsample_fraction is not None and outside the range (0, 1). - - """ - self.splits = splits - - if subsample_fraction is not None: - if not 0.0 < subsample_fraction < 1.0: - raise ValueError("The subsample fraction must be in (0, 1).") - - self.dataset_args = { - "sample_to_balance": sample_to_balance, - "subsample_fraction": subsample_fraction, - "transform": transform, - "target_transform": target_transform, - "seed": seed, - } - self.batch_size = batch_size - self.shuffle = shuffle - self.num_workers = num_workers - self.cuda = cuda - self._data_loaders = self._load_data_loaders() - - def __repr__(self) -> str: - """Returns information about the dataset.""" - return self._data_loaders[self.splits[0]].dataset.__repr__() - - @property - def __name__(self) -> str: - """Returns the name of the dataset.""" - return "Emnist" - - def __call__(self, split: str) -> DataLoader: - """Returns the `split` DataLoader. - - Args: - split (str): The dataset split, i.e. train or val. - - Returns: - DataLoader: A PyTorch DataLoader. - - Raises: - ValueError: If the split does not exist. - - """ - try: - return self._data_loaders[split] - except KeyError: - raise ValueError(f"Split {split} does not exist.") - - def _load_data_loaders(self) -> Dict[str, DataLoader]: - """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" - data_loaders = {} - - for split in ["train", "val"]: - if split in self.splits: - - if split == "train": - self.dataset_args["train"] = True - else: - self.dataset_args["train"] = False - - emnist_dataset = EmnistDataset(**self.dataset_args) - - emnist_dataset.load_emnist_dataset() - - data_loader = DataLoader( - dataset=emnist_dataset, - batch_size=self.batch_size, - shuffle=self.shuffle, - num_workers=self.num_workers, - pin_memory=self.cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders |