-rw-r--r--  text_recognizer/data/iam_paragraphs.py   |  4
-rw-r--r--  text_recognizer/data/iam_preprocessor.py |  6
-rw-r--r--  text_recognizer/data/mappings.py         | 10
-rw-r--r--  text_recognizer/data/transforms.py       | 13
4 files changed, 27 insertions, 6 deletions
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 6022804..fe60e99 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -17,6 +17,7 @@ from text_recognizer.data.base_dataset import (
 )
 from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
 from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import WordPieceMapping
 from text_recognizer.data.iam import IAM
 from text_recognizer.data.transforms import WordPiece
@@ -49,6 +50,9 @@ class IAMParagraphs(BaseDataModule):
         self.mapping, self.inverse_mapping, _ = emnist_mapping(
             extra_symbols=[NEW_LINE_TOKEN]
         )
+        if word_pieces:
+            self.mapping = WordPieceMapping()
+
         self.train_fraction = train_fraction
         self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
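
The iam_paragraphs.py change above switches the data module's target mapping to the word-piece vocabulary when word pieces are requested. A minimal sketch (not part of the commit) of the resulting selection logic, assuming word_pieces is a boolean flag on IAMParagraphs.__init__, NEW_LINE_TOKEN is the newline character, and WordPieceMapping() is usable with its default arguments:

    from text_recognizer.data.emnist import emnist_mapping
    from text_recognizer.data.mappings import WordPieceMapping

    # Hypothetical helper for illustration only.
    def build_mapping(word_pieces: bool = False):
        # Character-level EMNIST mapping, extended with a newline token.
        mapping, inverse_mapping, _ = emnist_mapping(extra_symbols=["\n"])
        if word_pieces:
            # Override the character mapping with the word-piece vocabulary.
            mapping = WordPieceMapping()
        return mapping, inverse_mapping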
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index b5f72da..506036e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -89,9 +89,9 @@ class Preprocessor:
         self.lexicon = None
 
         if self.special_tokens is not None:
-            self.special_tokens += ("#", "*")
-            self.tokens += self.special_tokens
-            self.graphemes += self.special_tokens
+            special_tokens_ = (*self.special_tokens, "#", "*")
+            self.tokens += special_tokens_
+            self.graphemes += special_tokens_
 
         self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
         self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
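
The iam_preprocessor.py fix keeps the extra placeholder symbols out of self.special_tokens: tokens and graphemes still gain "#" and "*", but the attribute now reflects only what the caller passed in. A small illustration of the difference, assuming special_tokens is a caller-supplied sequence:

    # Old behaviour: the attribute itself grows.
    special_tokens = ("<s>", "<e>", "<p>")
    special_tokens += ("#", "*")
    # special_tokens is now ("<s>", "<e>", "<p>", "#", "*")

    # New behaviour: a local tuple is extended instead.
    special_tokens = ("<s>", "<e>", "<p>")
    special_tokens_ = (*special_tokens, "#", "*")
    # special_tokens is unchanged; special_tokens_ carries the extra symbols.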
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 190febe..0d778b2 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -125,6 +125,9 @@ class WordPieceMapping(EmnistMapping):
             special_tokens,
         )
 
+    def __len__(self) -> int:
+        return len(self.wordpiece_processor.tokens)
+
     def get_token(self, index: Union[int, Tensor]) -> str:
         if (index := int(index)) <= self.wordpiece_processor.num_tokens:
             return self.wordpiece_processor.tokens[index]
@@ -132,7 +135,7 @@ class WordPieceMapping(EmnistMapping):
 
     def get_index(self, token: str) -> Tensor:
         if token in self.wordpiece_processor.tokens:
-            return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token])
+            return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
         raise KeyError(f"Token ({token}) not found in inverse mapping.")
 
     def get_text(self, indices: Union[List[int], Tensor]) -> str:
@@ -147,3 +150,8 @@ class WordPieceMapping(EmnistMapping):
text = "".join([self.mapping[i] for i in x])
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
+
+ def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+ if isinstance(x, str):
+ return self.get_index(x)
+ return self.get_token(x)
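
The mappings.py additions make WordPieceMapping behave like a standard container (len(mapping), and mapping[...] for both tokens and indices). The get_index change is a genuine bug fix: torch.LongTensor called with a bare Python int allocates an uninitialised tensor of that length, whereas a one-element list produces a tensor holding the index. A quick illustration:

    import torch

    torch.LongTensor(3)    # size-3 tensor with arbitrary contents, e.g. tensor([0, 0, 0])
    torch.LongTensor([3])  # tensor([3]) -- the single index that was intended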
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index d0f1f35..66531a5 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -2,6 +2,7 @@
 from pathlib import Path
 from typing import Optional, Union, Sequence
 
+import torch
 from torch import Tensor
 
 from text_recognizer.data.mappings import WordPieceMapping
@@ -20,7 +21,7 @@ class WordPiece:
         prepend_wordsep: bool = False,
         special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
         extra_symbols: Optional[Sequence[str]] = ("\n",),
-        max_len: int = 192,
+        max_len: int = 451,
     ) -> None:
         self.mapping = WordPieceMapping(
             num_features,
@@ -35,4 +36,12 @@ class WordPiece:
         self.max_len = max_len
 
     def __call__(self, x: Tensor) -> Tensor:
-        return self.mapping.emnist_to_wordpiece_indices(x)[: self.max_len]
+        y = self.mapping.emnist_to_wordpiece_indices(x)
+        if len(y) < self.max_len:
+            pad_len = self.max_len - len(y)
+            y = torch.cat(
+                (y, torch.LongTensor([self.mapping.get_index("<p>")] * pad_len))
+            )
+        else:
+            y = y[: self.max_len]
+        return y
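
The transforms.py change turns WordPiece.__call__ into a fixed-length target transform: word-piece indices are padded with the <p> token up to max_len (now 451) or truncated to it. A standalone sketch (not part of the commit) of the same pad-or-truncate logic, with the pad index passed in as a plain int where the commit obtains it from self.mapping.get_index("<p>"):

    import torch
    from torch import Tensor

    # Hypothetical helper for illustration only.
    def pad_or_truncate(y: Tensor, max_len: int, pad_index: int) -> Tensor:
        """Pad a 1-D index tensor with pad_index up to max_len, or cut it down."""
        if len(y) < max_len:
            padding = torch.LongTensor([pad_index] * (max_len - len(y)))
            return torch.cat((y, padding))
        return y[:max_len]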