diff options
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 3e10b5f..eda490a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,6 +1,6 @@
 """EMNIST dataset: downloads it from FSDL aws url if not present."""
 from pathlib import Path
-from typing import Dict, List, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple
 import json
 import os
 import shutil
@@ -52,7 +52,7 @@ class EMNIST(BaseDataModule):
         self.data_val = None
         self.data_test = None
         self.transform = transforms.Compose([transforms.ToTensor()])
-        self.dims = (1, * self.input_shape)
+        self.dims = (1, *self.input_shape)
         self.output_dims = (1,)
 
     def prepare_data(self) -> None:
@@ -95,13 +95,17 @@ class EMNIST(BaseDataModule):
         return basic + data
 
 
-def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]:
+def emnist_mapping(
+    extra_symbols: Optional[List[str]],
+) -> Tuple[List, Dict[str, int], List[int]]:
     """Return the EMNIST mapping."""
     if not ESSENTIALS_FILENAME.exists():
         download_and_process_emnist()
     with ESSENTIALS_FILENAME.open() as f:
         essentials = json.load(f)
     mapping = list(essentials["characters"])
+    if extra_symbols is not None:
+        mapping += extra_symbols
     inverse_mapping = {v: k for k, v in enumerate(mapping)}
     input_shape = essentials["input_shape"]
     return mapping, inverse_mapping, input_shape