summaryrefslogtreecommitdiff
path: root/text_recognizer/data/emnist_mapping.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 19:14:16 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-03 19:14:16 +0200
commit65d5f6c694e73792e40ed693a1381a792da8d277 (patch)
treec97067914c18b47108188c51294abd40bb1d6481 /text_recognizer/data/emnist_mapping.py
parentbd4bd443f339e95007bfdabf3e060db720f4d4b9 (diff)
Fix bugs in converting text in mappings, add missing word_piece arg in datamodule
Diffstat (limited to 'text_recognizer/data/emnist_mapping.py')
-rw-r--r--text_recognizer/data/emnist_mapping.py8
1 files changed, 6 insertions, 2 deletions
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 6c4c43b..925d214 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -1,6 +1,7 @@
"""Emnist mapping."""
from typing import List, Optional, Union, Set
+import torch
from torch import Tensor
from text_recognizer.data.base_mapping import AbstractMapping
@@ -19,13 +20,13 @@ class EmnistMapping(AbstractMapping):
"""Post init configuration."""
def get_token(self, index: Union[int, Tensor]) -> str:
- if (index := int(index)) in self.mapping:
+ if (index := int(index)) <= len(self.mapping):
return self.mapping[index]
raise KeyError(f"Index ({index}) not in mapping.")
def get_index(self, token: str) -> Tensor:
if token in self.inverse_mapping:
- return Tensor(self.inverse_mapping[token])
+ return torch.LongTensor([self.inverse_mapping[token]])
raise KeyError(f"Token ({token}) not found in inverse mapping.")
def get_text(self, indices: Union[List[int], Tensor]) -> str:
@@ -35,3 +36,6 @@ class EmnistMapping(AbstractMapping):
def get_indices(self, text: str) -> Tensor:
return Tensor([self.inverse_mapping[token] for token in text])
+
+ def __getitem__(self, x: Union[int, Tensor]) -> str:
+ return self.get_token(x)