diff options
Diffstat (limited to 'text_recognizer/data/iam_preprocessor.py')
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index f7457e4..93a13bb 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -9,7 +9,7 @@ import collections import itertools from pathlib import Path import re -from typing import List, Optional, Union +from typing import List, Optional, Union, Sequence import click from loguru import logger @@ -57,15 +57,13 @@ class Preprocessor: lexicon_path: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, - special_tokens: Optional[List[str]] = None, + special_tokens: Optional[Sequence[str]] = None, ) -> None: self.wordsep = "▁" self._use_word = use_words self._prepend_wordsep = prepend_wordsep self.special_tokens = special_tokens if special_tokens is not None else None - self.data_dir = Path(data_dir) - self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words) # Load the set of graphemes: @@ -123,7 +121,7 @@ class Preprocessor: self.text.append(example["text"].lower()) def _to_index(self, line: str) -> torch.LongTensor: - if line in self.special_tokens: + if self.special_tokens is not None and line in self.special_tokens: return torch.LongTensor([self.tokens_to_index[line]]) token_to_index = self.graphemes_to_index if self.lexicon is not None: |