Diffstat (limited to 'text_recognizer/data')
 text_recognizer/data/iam_paragraphs.py   |  4 ++++
 text_recognizer/data/iam_preprocessor.py |  6 +++---
 text_recognizer/data/mappings.py         | 10 +++++++++-
 text_recognizer/data/transforms.py       | 13 +++++++++++--
 4 files changed, 27 insertions(+), 6 deletions(-)
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 6022804..fe60e99 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -17,6 +17,7 @@ from text_recognizer.data.base_dataset import (
 )
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import WordPieceMapping
 from text_recognizer.data.iam import IAM
 from text_recognizer.data.transforms import WordPiece
 
@@ -49,6 +50,9 @@ class IAMParagraphs(BaseDataModule):
         self.mapping, self.inverse_mapping, _ = emnist_mapping(
             extra_symbols=[NEW_LINE_TOKEN]
         )
+        if word_pieces:
+            self.mapping = WordPieceMapping()
+
         self.train_fraction = train_fraction
         self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
 
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index b5f72da..506036e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -89,9 +89,9 @@ class Preprocessor:
             self.lexicon = None
 
         if self.special_tokens is not None:
-            self.special_tokens += ("#", "*")
-            self.tokens += self.special_tokens
-            self.graphemes += self.special_tokens
+            special_tokens_ = (*self.special_tokens, "#", "*")
+            self.tokens += special_tokens_
+            self.graphemes += special_tokens_
 
         self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
         self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 190febe..0d778b2 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -125,6 +125,9 @@ class WordPieceMapping(EmnistMapping):
             special_tokens,
         )
 
+    def __len__(self) -> int:
+        return len(self.wordpiece_processor.tokens)
+
     def get_token(self, index: Union[int, Tensor]) -> str:
         if (index := int(index)) <= self.wordpiece_processor.num_tokens:
             return self.wordpiece_processor.tokens[index]
@@ -132,7 +135,7 @@ class WordPieceMapping(EmnistMapping):
 
     def get_index(self, token: str) -> Tensor:
         if token in self.wordpiece_processor.tokens:
-            return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token])
+            return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
         raise KeyError(f"Token ({token}) not found in inverse mapping.")
 
     def get_text(self, indices: Union[List[int], Tensor]) -> str:
@@ -147,3 +150,8 @@ class WordPieceMapping(EmnistMapping):
         text = "".join([self.mapping[i] for i in x])
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))
+
+    def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+        if isinstance(x, str):
+            return self.get_index(x)
+        return self.get_token(x)
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index d0f1f35..66531a5 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Optional, Union, Sequence
 
+import torch
 from torch import Tensor
 
 from text_recognizer.data.mappings import WordPieceMapping
@@ -20,7 +21,7 @@ class WordPiece:
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
         extra_symbols: Optional[Sequence[str]] = ("\n",),
-        max_len: int = 192,
+        max_len: int = 451,
     ) -> None:
         self.mapping = WordPieceMapping(
             num_features,
@@ -35,4 +36,12 @@ class WordPiece:
         self.max_len = max_len
 
     def __call__(self, x: Tensor) -> Tensor:
-        return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len]
+        y = self.mapping.emnist_to_wordpiece_indices(x)
+        if len(y) < self.max_len:
+            pad_len = self.max_len - len(y)
+            y = torch.cat(
+                (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len))
+            )
+        else:
+            y = y[: self.max_len]
+        return y
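
Note on iam_preprocessor.py: building a local special_tokens_ tuple instead of doing self.special_tokens += ("#", "*") keeps the sentinel symbols out of the public attribute, so callers that pass in special_tokens never see "#" and "*" appended to it. A minimal sketch of the difference, using simplified stand-in classes rather than the real Preprocessor:

    class Before:
        def __init__(self, special_tokens=("<s>", "<e>")):
            self.special_tokens = special_tokens
            self.special_tokens += ("#", "*")  # rebinds the attribute; sentinels leak out

    class After:
        def __init__(self, special_tokens=("<s>", "<e>")):
            self.special_tokens = special_tokens
            special_tokens_ = (*self.special_tokens, "#", "*")  # local copy only
            self.tokens = list(special_tokens_)

    print(Before().special_tokens)  # ('<s>', '<e>', '#', '*')
    print(After().special_tokens)   # ('<s>', '<e>')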
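
Note on mappings.py: the get_index fix is more than cosmetic. The legacy torch.LongTensor constructor treats a bare integer as a size, so the old code returned an uninitialized tensor whose length equaled the token's index; wrapping the index in a list yields the intended one-element tensor. A quick illustration with a made-up index value:

    import torch

    index = 42
    wrong = torch.LongTensor(index)    # uninitialized tensor with 42 elements
    right = torch.LongTensor([index])  # one-element tensor holding the value 42
    print(wrong.shape)  # torch.Size([42])
    print(right)        # tensor([42])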
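
Together, the new __len__ and __getitem__ give WordPieceMapping a dict-like interface: string keys dispatch to get_index, integer or tensor keys to get_token. A hypothetical usage sketch, assuming the mapping can be constructed with its defaults:

    from text_recognizer.data.mappings import WordPieceMapping

    mapping = WordPieceMapping()
    print(len(mapping))    # total number of word-piece tokens, via __len__
    print(mapping["the"])  # str key -> index tensor, via get_index
    print(mapping[42])     # int key -> token string, via get_token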
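
Note on transforms.py: WordPiece.__call__ now produces fixed-length targets by padding short sequences with the index of the "<p>" token up to max_len (raised to 451) and truncating longer ones. The padding logic in isolation, with a hypothetical PAD_INDEX standing in for self.mapping.get_index("<p>"):

    import torch
    from torch import Tensor

    PAD_INDEX = 3  # hypothetical stand-in for mapping.get_index("<p>")

    def pad_or_truncate(y: Tensor, max_len: int) -> Tensor:
        """Pad a 1D index tensor with PAD_INDEX up to max_len, or cut it off."""
        if len(y) < max_len:
            pad = torch.LongTensor([PAD_INDEX] * (max_len - len(y)))
            return torch.cat((y, pad))
        return y[:max_len]

    print(pad_or_truncate(torch.LongTensor([7, 8, 9]), 5))  # tensor([7, 8, 9, 3, 3])
    print(pad_or_truncate(torch.LongTensor([7, 8, 9]), 2))  # tensor([7, 8])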