summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-08 23:38:03 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-08 23:38:03 +0200
commite388cd95c77d37a51324cff9d84a809421bf97d3 (patch)
treed585545f85d03ea8a6907daba254821fddeb1589 /text_recognizer/data
parentf4629a0d4149d5870c9fd8ce83ff5d391bd7ddd3 (diff)
Bug fixes word pieces
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/base_dataset.py2
-rw-r--r--text_recognizer/data/iam.py1
-rw-r--r--text_recognizer/data/iam_preprocessor.py28
-rw-r--r--text_recognizer/data/transforms.py6
4 files changed, 25 insertions, 12 deletions
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index d00daaf..8d644d4 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -67,7 +67,7 @@ def convert_strings_to_labels(
labels = torch.ones((len(strings), length), dtype=torch.long) * mapping["<p>"]
for i, string in enumerate(strings):
tokens = list(string)
- tokens = ["<s>", *tokens, "</s>"]
+ tokens = ["<s>", *tokens, "<e>"]
for j, token in enumerate(tokens):
labels[i, j] = mapping[token]
return labels
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 01272ba..261c8d3 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,6 @@ import zipfile
from boltons.cacheutils import cachedproperty
from loguru import logger
-from PIL import Image
import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 3844419..d85787e 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -47,8 +47,6 @@ def load_metadata(
class Preprocessor:
"""A preprocessor for the IAM dataset."""
- # TODO: add lower case only to when generating...
-
def __init__(
self,
data_dir: Union[str, Path],
@@ -57,10 +55,12 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
+ special_tokens: Optional[List[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
+ self.special_tokens = special_tokens if special_tokens is not None else None
self.data_dir = Path(data_dir)
@@ -88,6 +88,10 @@ class Preprocessor:
else:
self.lexicon = None
+ if self.special_tokens is not None:
+ self.tokens += self.special_tokens
+ self.graphemes += self.special_tokens
+
self.graphemes_to_index = {t: i for i, t in enumerate(self.graphemes)}
self.tokens_to_index = {t: i for i, t in enumerate(self.tokens)}
self.num_features = num_features
@@ -115,21 +119,31 @@ class Preprocessor:
continue
self.text.append(example["text"].lower())
- def to_index(self, line: str) -> torch.LongTensor:
- """Converts text to a tensor of indices."""
+
+ def _to_index(self, line: str) -> torch.LongTensor:
+ if line in self.special_tokens:
+ return torch.LongTensor([self.tokens_to_index[line]])
token_to_index = self.graphemes_to_index
if self.lexicon is not None:
if len(line) > 0:
# If the word is not found in the lexicon, fall back to letters.
- line = [
+ tokens = [
t
for w in line.split(self.wordsep)
for t in self.lexicon.get(w, self.wordsep + w)
]
token_to_index = self.tokens_to_index
if self._prepend_wordsep:
- line = itertools.chain([self.wordsep], line)
- return torch.LongTensor([token_to_index[t] for t in line])
+ tokens = itertools.chain([self.wordsep], tokens)
+ return torch.LongTensor([token_to_index[t] for t in tokens])
+
+ def to_index(self, line: str) -> torch.LongTensor:
+ """Converts text to a tensor of indices."""
+ if self.special_tokens is not None:
+ pattern = f"({'|'.join(self.special_tokens)})"
+ lines = list(filter(None, re.split(pattern, line)))
+ return torch.cat([self._to_index(l) for l in lines])
+ return self._to_index(line)
def to_text(self, indices: List[int]) -> str:
"""Converts indices to text."""
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 616e236..297c953 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -23,12 +23,12 @@ class ToLower:
class ToCharcters:
"""Converts integers to characters."""
- def __init__(self) -> None:
- self.mapping, _, _ = emnist_mapping()
+ def __init__(self, extra_symbols: Optional[List[str]] = None) -> None:
+ self.mapping, _, _ = emnist_mapping(extra_symbols)
def __call__(self, y: Tensor) -> str:
"""Converts a Tensor to a str."""
- return "".join([self.mapping(int(i)) for i in y]).strip("<p>").replace(" ", "▁")
+ return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁")
class WordPieces: