diff options
Diffstat (limited to 'text_recognizer/data/iam_preprocessor.py')
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index 3844419..d85787e 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -47,8 +47,6 @@ def load_metadata( class Preprocessor: """A preprocessor for the IAM dataset.""" - # TODO: add lower case only to when generating... - def __init__( self, data_dir: Union[str, Path], @@ -57,10 +55,12 @@ class Preprocessor: lexicon_path: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, + special_tokens: Optional[List[str]] = None, ) -> None: self.wordsep = "▁" self._use_word = use_words self._prepend_wordsep = prepend_wordsep + self.special_tokens = special_tokens if special_tokens is not None else None self.data_dir = Path(data_dir) @@ -88,6 +88,10 @@ class Preprocessor: else: self.lexicon = None + if self.special_tokens is not None: + self.tokens += self.special_tokens + self.graphemes += self.special_tokens + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} self.num_features = num_features @@ -115,21 +119,31 @@ class Preprocessor: continue self.text.append(example["text"].lower()) - def to_index(self, line: str) -> torch.LongTensor: - """Converts text to a tensor of indices.""" + + def _to_index(self, line: str) -> torch.LongTensor: + if line in self.special_tokens: + return torch.LongTensor([self.tokens_to_index[line]]) token_to_index = self.graphemes_to_index if self.lexicon is not None: if len(line) > 0: # If the word is not found in the lexicon, fall back to letters. - line = [ + tokens = [ t for w in line.split(self.wordsep) for t in self.lexicon.get(w, self.wordsep + w) ] token_to_index = self.tokens_to_index if self._prepend_wordsep: - line = itertools.chain([self.wordsep], line) - return torch.LongTensor([token_to_index[t] for t in line]) + tokens = itertools.chain([self.wordsep], tokens) + return torch.LongTensor([token_to_index[t] for t in tokens]) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + if self.special_tokens is not None: + pattern = f"({'|'.join(self.special_tokens)})" + lines = list(filter(None, re.split(pattern, line))) + return torch.cat([self._to_index(l) for l in lines]) + return self._to_index(line) def to_text(self, indices: List[int]) -> str: """Converts indices to text.""" |