diff options
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/emnist_mapping.py | 8 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/word_piece_mapping.py | 10 |
3 files changed, 11 insertions, 9 deletions
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py index 6c4c43b..925d214 100644 --- a/text_recognizer/data/emnist_mapping.py +++ b/text_recognizer/data/emnist_mapping.py @@ -1,6 +1,7 @@ """Emnist mapping.""" from typing import List, Optional, Union, Set +import torch from torch import Tensor from text_recognizer.data.base_mapping import AbstractMapping @@ -19,13 +20,13 @@ class EmnistMapping(AbstractMapping): """Post init configuration.""" def get_token(self, index: Union[int, Tensor]) -> str: - if (index := int(index)) in self.mapping: + if (index := int(index)) <= len(self.mapping): return self.mapping[index] raise KeyError(f"Index ({index}) not in mapping.") def get_index(self, token: str) -> Tensor: if token in self.inverse_mapping: - return Tensor(self.inverse_mapping[token]) + return torch.LongTensor([self.inverse_mapping[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") def get_text(self, indices: Union[List[int], Tensor]) -> str: @@ -35,3 +36,6 @@ class EmnistMapping(AbstractMapping): def get_indices(self, text: str) -> Tensor: return Tensor([self.inverse_mapping[token] for token in text]) + + def __getitem__(self, x: Union[int, Tensor]) -> str: + return self.get_token(x) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 11f899f..3509b92 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -40,6 +40,8 @@ class IAMParagraphs(BaseDataModule): word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) + + # Placeholders dims: Tuple[int, int, int] = attr.ib( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) diff --git a/text_recognizer/data/word_piece_mapping.py b/text_recognizer/data/word_piece_mapping.py index 59488c3..2f650cd 100644 --- a/text_recognizer/data/word_piece_mapping.py +++ b/text_recognizer/data/word_piece_mapping.py @@ -75,7 +75,7 @@ class WordPieceMapping(EmnistMapping): def get_text(self, indices: Union[List[int], Tensor]) -> str: if isinstance(indices, Tensor): indices = indices.tolist() - return self.wordpiece_processor.to_text(indices).replace(" ", "▁") + return self.wordpiece_processor.to_text(indices) def get_indices(self, text: str) -> Tensor: return self.wordpiece_processor.to_index(text) @@ -85,9 +85,5 @@ class WordPieceMapping(EmnistMapping): text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) - def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]: - if isinstance(x, int): - x = [x] - if isinstance(x, str): - return self.get_indices(x) - return self.get_text(x) + def __getitem__(self, x: Union[int, Tensor]) -> str: + return self.get_token(x) |