summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_mapping.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-08 19:59:55 +0200
commit240f5e9f20032e82515fa66ce784619527d1041e (patch)
treeb002d28bbfc9abe9b6af090f7db60bea0aeed6e8 /text_recognizer/data/emnist_mapping.py
parentd12f70402371dda586d457af2a3df7fb5b3130ad (diff)
Add VQGAN and loss function
Diffstat (limited to 'text_recognizer/data/emnist_mapping.py')
-rw-r--r--text_recognizer/data/emnist_mapping.py6
1 files changed, 5 insertions, 1 deletions
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 3e91594..4406db7 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -9,7 +9,9 @@ from text_recognizer.data.emnist import emnist_mapping
class EmnistMapping(AbstractMapping):
- def __init__(self, extra_symbols: Optional[Set[str]] = None, lower: bool = True) -> None:
+ def __init__(
+ self, extra_symbols: Optional[Set[str]] = None, lower: bool = True
+ ) -> None:
self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
self.extra_symbols
@@ -20,10 +22,12 @@ class EmnistMapping(AbstractMapping):
def _to_lower(self) -> None:
"""Converts mapping to lowercase letters only."""
+
def _filter(x: int) -> int:
if 40 <= x:
return x - 26
return x
+
self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)}
self.mapping = [c for c in self.mapping if not c.isupper()]