diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 6 | ||||
-rw-r--r-- | text_recognizer/data/mappings.py | 10 | ||||
-rw-r--r-- | text_recognizer/data/transforms.py | 13 |
4 files changed, 27 insertions, 6 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 6022804..fe60e99 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -17,6 +17,7 @@ from text_recognizer.data.base_dataset import ( ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import WordPieceMapping from text_recognizer.data.iam import IAM from text_recognizer.data.transforms import WordPiece @@ -49,6 +50,9 @@ class IAMParagraphs(BaseDataModule): self.mapping, self.inverse_mapping, _ = emnist_mapping( extra_symbols=[NEW_LINE_TOKEN] ) + if word_pieces: + self.mapping = WordPieceMapping() + self.train_fraction = train_fraction self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index b5f72da..506036e 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -89,9 +89,9 @@ class Preprocessor: self.lexicon = None if self.special_tokens is not None: - self.special_tokens += ("#", "*") - self.tokens += self.special_tokens - self.graphemes += self.special_tokens + special_tokens_ = (*self.special_tokens, "#", "*") + self.tokens += special_tokens_ + self.graphemes += special_tokens_ self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index 190febe..0d778b2 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -125,6 +125,9 @@ class WordPieceMapping(EmnistMapping): special_tokens, ) + def __len__(self) -> int: + return len(self.wordpiece_processor.tokens) + def get_token(self, index: Union[int, Tensor]) -> str: if (index := int(index)) <= self.wordpiece_processor.num_tokens: return self.wordpiece_processor.tokens[index] @@ -132,7 +135,7 @@ class WordPieceMapping(EmnistMapping): def get_index(self, token: str) -> Tensor: if token in self.wordpiece_processor.tokens: - return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token]) + return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]]) raise KeyError(f"Token ({token}) not found in inverse mapping.") def get_text(self, indices: Union[List[int], Tensor]) -> str: @@ -147,3 +150,8 @@ class WordPieceMapping(EmnistMapping): text = "".join([self.mapping[i] for i in x]) text = text.lower().replace(" ", "▁") return torch.LongTensor(self.wordpiece_processor.to_index(text)) + + def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]: + if isinstance(x, str): + return self.get_index(x) + return self.get_token(x) diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index d0f1f35..66531a5 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -2,6 +2,7 @@ from pathlib import Path from typing import Optional, Union, Sequence +import torch from torch import Tensor from text_recognizer.data.mappings import WordPieceMapping @@ -20,7 +21,7 @@ class WordPiece: prepend_wordsep: bool = False, special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), extra_symbols: Optional[Sequence[str]] = ("\n",), - max_len: int = 192, + max_len: int = 451, ) -> None: self.mapping = WordPieceMapping( num_features, @@ -35,4 +36,12 @@ class WordPiece: self.max_len = max_len def __call__(self, x: Tensor) -> Tensor: - return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len] + y = self.mapping.emnist_to_wordpiece_indices(x) + if len(y) < self.max_len: + pad_len = self.max_len - len(y) + y = torch.cat( + (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len)) + ) + else: + y = y[: self.max_len] + return y |