diff options
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/base_dataset.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/iam.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 28 | ||||
-rw-r--r-- | text_recognizer/data/transforms.py | 6 |
4 files changed, 25 insertions, 12 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index d00daaf..8d644d4 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -67,7 +67,7 @@ def convert_strings_to_labels( labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"] for i, string in enumerate(strings): tokens = list(string) - tokens = ["<s>", *tokens, "</s>"] + tokens = ["<s>", *tokens, "<e>"] for j, token in enumerate(tokens): labels[i, j] = mapping[token] return labels diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 01272ba..261c8d3 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -7,7 +7,6 @@ import zipfile from boltons.cacheutils import cachedproperty from loguru import logger -from PIL import Image import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index 3844419..d85787e 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -47,8 +47,6 @@ def load_metadata( class Preprocessor: """A preprocessor for the IAM dataset.""" - # TODO: add lower case only to when generating... - def __init__( self, data_dir: Union[str, Path], @@ -57,10 +55,12 @@ class Preprocessor: lexicon_path: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, + special_tokens: Optional[List[str]] = None, ) -> None: self.wordsep = "▁" self._use_word = use_words self._prepend_wordsep = prepend_wordsep + self.special_tokens = special_tokens if special_tokens is not None else None self.data_dir = Path(data_dir) @@ -88,6 +88,10 @@ class Preprocessor: else: self.lexicon = None + if self.special_tokens is not None: + self.tokens += self.special_tokens + self.graphemes += self.special_tokens + self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)} self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)} self.num_features = num_features @@ -115,21 +119,31 @@ class Preprocessor: continue self.text.append(example["text"].lower()) - def to_index(self, line: str) -> torch.LongTensor: - """Converts text to a tensor of indices.""" + + def _to_index(self, line: str) -> torch.LongTensor: + if line in self.special_tokens: + return torch.LongTensor([self.tokens_to_index[line]]) token_to_index = self.graphemes_to_index if self.lexicon is not None: if len(line) > 0: # If the word is not found in the lexicon, fall back to letters. - line = [ + tokens = [ t for w in line.split(self.wordsep) for t in self.lexicon.get(w, self.wordsep + w) ] token_to_index = self.tokens_to_index if self._prepend_wordsep: - line = itertools.chain([self.wordsep], line) - return torch.LongTensor([token_to_index[t] for t in line]) + tokens = itertools.chain([self.wordsep], tokens) + return torch.LongTensor([token_to_index[t] for t in tokens]) + + def to_index(self, line: str) -> torch.LongTensor: + """Converts text to a tensor of indices.""" + if self.special_tokens is not None: + pattern = f"({'|'.join(self.special_tokens)})" + lines = list(filter(None, re.split(pattern, line))) + return torch.cat([self._to_index(l) for l in lines]) + return self._to_index(line) def to_text(self, indices: List[int]) -> str: """Converts indices to text.""" diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index 616e236..297c953 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -23,12 +23,12 @@ class ToLower: class ToCharcters: """Converts integers to characters.""" - def __init__(self) -> None: - self.mapping, _, _ = emnist_mapping() + def __init__(self, extra_symbols: Optional[List[str]] = None) -> None: + self.mapping, _, _ = emnist_mapping(extra_symbols) def __call__(self, y: Tensor) -> str: """Converts a Tensor to a str.""" - return "".join([self.mapping(int(i)) for i in y]).strip("<p>").replace(" ", "▁") + return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁") class WordPieces: |