From 65d5f6c694e73792e40ed693a1381a792da8d277 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Tue, 3 Aug 2021 19:14:16 +0200
Subject: Fix bugs in converting text in mappings, add missing word_piece arg
 in datamodule

---
 text_recognizer/data/emnist_mapping.py     |  8 ++++++--
 text_recognizer/data/iam_paragraphs.py     |  2 ++
 text_recognizer/data/word_piece_mapping.py | 10 +++-------
 3 files changed, 11 insertions(+), 9 deletions(-)

(limited to 'text_recognizer/data')

diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 6c4c43b..925d214 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -1,6 +1,7 @@
 """Emnist mapping."""
 from typing import List, Optional, Union, Set
 
+import torch
 from torch import Tensor
 
 from text_recognizer.data.base_mapping import AbstractMapping
@@ -19,13 +20,13 @@ class EmnistMapping(AbstractMapping):
         """Post init configuration."""
 
     def get_token(self, index: Union[int, Tensor]) -> str:
-        if (index := int(index)) in self.mapping:
+        if (index := int(index)) <= len(self.mapping):
             return self.mapping[index]
         raise KeyError(f"Index ({index}) not in mapping.")
 
     def get_index(self, token: str) -> Tensor:
         if token in self.inverse_mapping:
-            return Tensor(self.inverse_mapping[token])
+            return torch.LongTensor([self.inverse_mapping[token]])
         raise KeyError(f"Token ({token}) not found in inverse mapping.")
 
     def get_text(self, indices: Union[List[int], Tensor]) -> str:
@@ -35,3 +36,6 @@ class EmnistMapping(AbstractMapping):
 
     def get_indices(self, text: str) -> Tensor:
         return Tensor([self.inverse_mapping[token] for token in text])
+
+    def __getitem__(self, x: Union[int, Tensor]) -> str:
+        return self.get_token(x)
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 11f899f..3509b92 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -40,6 +40,8 @@ class IAMParagraphs(BaseDataModule):
     word_pieces: bool = attr.ib(default=False)
     augment: bool = attr.ib(default=True)
     train_fraction: float = attr.ib(default=0.8)
+
+    # Placeholders
     dims: Tuple[int, int, int] = attr.ib(
         init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
     )
diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py
index 59488c3..2f650cd 100644
--- a/text_recognizer/data/word_piece_mapping.py
+++ b/text_recognizer/data/word_piece_mapping.py
@@ -75,7 +75,7 @@ class WordPieceMapping(EmnistMapping):
     def get_text(self, indices: Union[List[int], Tensor]) -> str:
         if isinstance(indices, Tensor):
             indices = indices.tolist()
-        return self.wordpiece_processor.to_text(indices).replace(" ", "▁")
+        return self.wordpiece_processor.to_text(indices)
 
     def get_indices(self, text: str) -> Tensor:
         return self.wordpiece_processor.to_index(text)
@@ -85,9 +85,5 @@ class WordPieceMapping(EmnistMapping):
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))
 
-    def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
-        if isinstance(x, int):
-            x = [x]
-        if isinstance(x, str):
-            return self.get_indices(x)
-        return self.get_text(x)
+    def __getitem__(self, x: Union[int, Tensor]) -> str:
+        return self.get_token(x)
-- 
cgit v1.2.3-70-g09d2