summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/iam_preprocessor.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam_preprocessor.py')
-rw-r--r-- text_recognizer/data/iam_preprocessor.py 8
1 files changed, 3 insertions, 5 deletions
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index f7457e4..93a13bb 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -9,7 +9,7 @@ import collections
import itertools
from pathlib import Path
import re
-from typing import List, Optional, Union
+from typing import List, Optional, Union, Sequence
import click
from loguru import logger
@@ -57,15 +57,13 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Optional[List[str]] = None,
+ special_tokens: Optional[Sequence[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
self.special_tokens = special_tokens if special_tokens is not None else None
-
self.data_dir = Path(data_dir)
-
self.forms = load_metadata(self.data_dir, self.wordsep, use_words=use_words)
# Load the set of graphemes:
@@ -123,7 +121,7 @@ class Preprocessor:
self.text.append(example["text"].lower())
def _to_index(self, line: str) -> torch.LongTensor:
- if line in self.special_tokens:
+ if self.special_tokens is not None and line in self.special_tokens:
return torch.LongTensor([self.tokens_to_index[line]])
token_to_index = self.graphemes_to_index
if self.lexicon is not None: