diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-03 21:59:07 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-03 21:59:07 +0200 |
commit | 07f2cc3665a1a60efe8ed8073cad6ac4f344b2c2 (patch) | |
tree | d24ae8e3b9b39bfcfb3b850b30cb966eb3b064a7 /text_recognizer/data/emnist.py | |
parent | 3196144ec99e803cef218295ddea592748931c57 (diff) |
Add IAM paragraphs dataset
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r-- | text_recognizer/data/emnist.py | 10 |
1 files changed, 7 insertions, 3 deletions
```diff
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 3e10b5f..eda490a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,6 +1,6 @@
 """EMNIST dataset: downloads it from FSDL aws url if not present."""
 from pathlib import Path
-from typing import Dict, List, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple
 import json
 import os
 import shutil
@@ -52,7 +52,7 @@ class EMNIST(BaseDataModule):
         self.data_val = None
         self.data_test = None
         self.transform = transforms.Compose([transforms.ToTensor()])
-        self.dims = (1, * self.input_shape)
+        self.dims = (1, *self.input_shape)
         self.output_dims = (1,)

     def prepare_data(self) -> None:
@@ -95,13 +95,17 @@ class EMNIST(BaseDataModule):
         return basic + data


-def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]:
+def emnist_mapping(
+    extra_symbols: Optional[List[str]],
+) -> Tuple[List, Dict[str, int], List[int]]:
     """Return the EMNIST mapping."""
     if not ESSENTIALS_FILENAME.exists():
         download_and_process_emnist()
     with ESSENTIALS_FILENAME.open() as f:
         essentials = json.load(f)
     mapping = list(essentials["characters"])
+    if extra_symbols is not None:
+        mapping += extra_symbols
     inverse_mapping = {v: k for k, v in enumerate(mapping)}
     input_shape = essentials["input_shape"]
     return mapping, inverse_mapping, input_shape
```