summaryrefslogtreecommitdiff
path: root/text_recognizer/data/mappings
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/mappings')
-rw-r--r--text_recognizer/data/mappings/base_mapping.py37
-rw-r--r--text_recognizer/data/mappings/emnist_essentials.json1
-rw-r--r--text_recognizer/data/mappings/emnist_mapping.py60
-rw-r--r--text_recognizer/data/mappings/word_piece_mapping.py98
4 files changed, 196 insertions, 0 deletions
diff --git a/text_recognizer/data/mappings/base_mapping.py b/text_recognizer/data/mappings/base_mapping.py
new file mode 100644
index 0000000..572ac95
--- /dev/null
+++ b/text_recognizer/data/mappings/base_mapping.py
@@ -0,0 +1,37 @@
+"""Mapping to and from word pieces."""
+from abc import ABC, abstractmethod
+from typing import Dict, List
+
+from torch import Tensor
+
+
+class AbstractMapping(ABC):
+ def __init__(
+ self, input_size: List[int], mapping: List[str], inverse_mapping: Dict[str, int]
+ ) -> None:
+ self.input_size = input_size
+ self.mapping = mapping
+ self.inverse_mapping = inverse_mapping
+
+ def __len__(self) -> int:
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:
+ return self.__len__()
+
+ @abstractmethod
+ def get_token(self, *args, **kwargs) -> str:
+ ...
+
+ @abstractmethod
+ def get_index(self, *args, **kwargs) -> Tensor:
+ ...
+
+ @abstractmethod
+ def get_text(self, *args, **kwargs) -> str:
+ ...
+
+ @abstractmethod
+ def get_indices(self, *args, **kwargs) -> Tensor:
+ ...
diff --git a/text_recognizer/data/mappings/emnist_essentials.json b/text_recognizer/data/mappings/emnist_essentials.json
new file mode 100644
index 0000000..c412425
--- /dev/null
+++ b/text_recognizer/data/mappings/emnist_essentials.json
@@ -0,0 +1 @@
+{"characters": ["<b>", "<s>", "<e>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} \ No newline at end of file
diff --git a/text_recognizer/data/mappings/emnist_mapping.py b/text_recognizer/data/mappings/emnist_mapping.py
new file mode 100644
index 0000000..3eed3d8
--- /dev/null
+++ b/text_recognizer/data/mappings/emnist_mapping.py
@@ -0,0 +1,60 @@
+"""Emnist mapping."""
+from typing import List, Optional, Set, Union
+
+import torch
+from torch import Tensor
+
+from text_recognizer.data.mappings.base_mapping import AbstractMapping
+from text_recognizer.data.emnist import emnist_mapping
+
+
+class EmnistMapping(AbstractMapping):
+ """Mapping for EMNIST labels."""
+
+ def __init__(
+ self, extra_symbols: Optional[Set[str]] = None, lower: bool = True
+ ) -> None:
+ self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
+ self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
+ self.extra_symbols
+ )
+ if lower:
+ self._to_lower()
+ super().__init__(self.input_size, self.mapping, self.inverse_mapping)
+
+ def _to_lower(self) -> None:
+ """Converts mapping to lowercase letters only."""
+
+ def _filter(x: int) -> int:
+ if 40 <= x:
+ return x - 26
+ return x
+
+ self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)}
+ self.mapping = [c for c in self.mapping if not c.isupper()]
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ """Returns token for index value."""
+ if (index := int(index)) <= len(self.mapping):
+ return self.mapping[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ """Returns index value of token."""
+ if token in self.inverse_mapping:
+ return torch.LongTensor([self.inverse_mapping[token]])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ """Returns the text from a list of indices."""
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return "".join([self.mapping[index] for index in indices])
+
+ def get_indices(self, text: str) -> Tensor:
+ """Returns tensor of indices for a string."""
+ return Tensor([self.inverse_mapping[token] for token in text])
+
+ def __getitem__(self, x: Union[int, Tensor]) -> str:
+ """Returns text for a list of indices."""
+ return self.get_token(x)
diff --git a/text_recognizer/data/mappings/word_piece_mapping.py b/text_recognizer/data/mappings/word_piece_mapping.py
new file mode 100644
index 0000000..6f1790e
--- /dev/null
+++ b/text_recognizer/data/mappings/word_piece_mapping.py
@@ -0,0 +1,98 @@
+"""Word piece mapping."""
+from pathlib import Path
+from typing import List, Optional, Set, Union
+
+from loguru import logger as log
+import torch
+from torch import Tensor
+
+from text_recognizer.data.mappings.emnist_mapping import EmnistMapping
+from text_recognizer.data.utils.iam_preprocessor import Preprocessor
+
+
+class WordPieceMapping(EmnistMapping):
+ """Word piece mapping."""
+
+ def __init__(
+ self,
+ data_dir: Optional[Path] = None,
+ num_features: int = 1000,
+ tokens: str = "iamdb_1kwp_tokens_1000.txt",
+ lexicon: str = "iamdb_1kwp_lex_1000.txt",
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ special_tokens: Set[str] = {"<s>", "<e>", "<p>"},
+ extra_symbols: Set[str] = {"\n"},
+ ) -> None:
+ super().__init__(extra_symbols=extra_symbols)
+ self.data_dir = (
+ (
+ Path(__file__).resolve().parents[3]
+ / "data"
+ / "downloaded"
+ / "iam"
+ / "iamdb"
+ )
+ if data_dir is None
+ else Path(data_dir)
+ )
+ log.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
+
+ processed_path = (
+ Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+ )
+
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ special_tokens = set(special_tokens)
+ if self.extra_symbols is not None:
+ special_tokens = special_tokens | set(extra_symbols)
+
+ self.wordpiece_processor = Preprocessor(
+ data_dir=self.data_dir,
+ num_features=num_features,
+ tokens_path=tokens_path,
+ lexicon_path=lexicon_path,
+ use_words=use_words,
+ prepend_wordsep=prepend_wordsep,
+ special_tokens=special_tokens,
+ )
+
+ def __len__(self) -> int:
+ """Return number of word pieces."""
+ return len(self.wordpiece_processor.tokens)
+
+ def get_token(self, index: Union[int, Tensor]) -> str:
+ """Returns token for index."""
+ if (index := int(index)) <= self.wordpiece_processor.num_tokens:
+ return self.wordpiece_processor.tokens[index]
+ raise KeyError(f"Index ({index}) not in mapping.")
+
+ def get_index(self, token: str) -> Tensor:
+ """Returns index of token."""
+ if token in self.wordpiece_processor.tokens:
+ return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
+ raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+ def get_text(self, indices: Union[List[int], Tensor]) -> str:
+ """Returns text from indices."""
+ if isinstance(indices, Tensor):
+ indices = indices.tolist()
+ return self.wordpiece_processor.to_text(indices)
+
+ def get_indices(self, text: str) -> Tensor:
+ """Returns indices of text."""
+ return self.wordpiece_processor.to_index(text)
+
+ def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
+ """Returns word pieces indices from emnist indices."""
+ text = "".join([self.mapping[i] for i in x])
+ text = text.lower().replace(" ", "▁")
+ return torch.LongTensor(self.wordpiece_processor.to_index(text))
+
+ def __getitem__(self, x: Union[int, Tensor]) -> str:
+ """Returns token for word piece index."""
+ return self.get_token(x)