summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_preprocessor.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_preprocessor.py')
-rw-r--r--text_recognizer/data/iam_preprocessor.py28
1 files changed, 21 insertions, 7 deletions
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 3844419..d85787e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -47,8 +47,6 @@ def load_metadata(
class Preprocessor:
"""A preprocessor for the IAM dataset."""
- # TODO: add lower case only to when generating...
-
def __init__(
self,
data_dir: Union[str, Path],
@@ -57,10 +55,12 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
+ special_tokens: Optional[List[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
+ self.special_tokens = special_tokens if special_tokens is not None else None
self.data_dir = Path(data_dir)
@@ -88,6 +88,10 @@ class Preprocessor:
else:
self.lexicon = None
+ if self.special_tokens is not None:
+ self.tokens += self.special_tokens
+ self.graphemes += self.special_tokens
+
self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
self.num_features = num_features
@@ -115,21 +119,31 @@ class Preprocessor:
continue
self.text.append(example["text"].lower())
- def to_index(self, line: str) -> torch.LongTensor:
- """Converts text to a tensor of indices."""
+
+ def _to_index(self, line: str) -> torch.LongTensor:
+ if line in self.special_tokens:
+ return torch.LongTensor([self.tokens_to_index[line]])
token_to_index = self.graphemes_to_index
if self.lexicon is not None:
if len(line) > 0:
# If the word is not found in the lexicon, fall back to letters.
- line = [
+ tokens = [
t
for w in line.split(self.wordsep)
for t in self.lexicon.get(w, self.wordsep + w)
]
token_to_index = self.tokens_to_index
if self._prepend_wordsep:
- line = itertools.chain([self.wordsep], line)
- return torch.LongTensor([token_to_index[t] for t in line])
+ tokens = itertools.chain([self.wordsep], tokens)
+ return torch.LongTensor([token_to_index[t] for t in tokens])
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ if self.special_tokens is not None:
+ pattern = f"({'|'.join(self.special_tokens)})"
+ lines = list(filter(None, re.split(pattern, line)))
+ return torch.cat([self._to_index(l) for l in lines])
+ return self._to_index(line)
def to_text(self, indices: List[int]) -> str:
"""Converts indices to text."""