summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/emnist.py')
-rw-r--r--text_recognizer/data/emnist.py10
1 files changed, 7 insertions, 3 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 3e10b5f..eda490a 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -1,6 +1,6 @@
"""EMNIST dataset: downloads it from FSDL aws url if not present."""
from pathlib import Path
-from typing import Dict, List, Sequence, Tuple
+from typing import Dict, List, Optional, Sequence, Tuple
import json
import os
import shutil
@@ -52,7 +52,7 @@ class EMNIST(BaseDataModule):
self.data_val = None
self.data_test = None
self.transform = transforms.Compose([transforms.ToTensor()])
- self.dims = (1, * self.input_shape)
+ self.dims = (1, *self.input_shape)
self.output_dims = (1,)
def prepare_data(self) -> None:
@@ -95,13 +95,17 @@ class EMNIST(BaseDataModule):
return basic + data
-def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]:
+def emnist_mapping(
+ extra_symbols: Optional[List[str]],
+) -> Tuple[List, Dict[str, int], List[int]]:
"""Return the EMNIST mapping."""
if not ESSENTIALS_FILENAME.exists():
download_and_process_emnist()
with ESSENTIALS_FILENAME.open() as f:
essentials = json.load(f)
mapping = list(essentials["characters"])
+ if extra_symbols is not None:
+ mapping += extra_symbols
inverse_mapping = {v: k for k, v in enumerate(mapping)}
input_shape = essentials["input_shape"]
return mapping, inverse_mapping, input_shape