diff options
Diffstat (limited to 'text_recognizer/data/mappings')
-rw-r--r-- | text_recognizer/data/mappings/base_mapping.py | 37 | ||||
-rw-r--r-- | text_recognizer/data/mappings/emnist_essentials.json | 1 | ||||
-rw-r--r-- | text_recognizer/data/mappings/emnist_mapping.py | 60 | ||||
-rw-r--r-- | text_recognizer/data/mappings/word_piece_mapping.py | 98 |
4 files changed, 196 insertions, 0 deletions
diff --git a/text_recognizer/data/mappings/base_mapping.py b/text_recognizer/data/mappings/base_mapping.py new file mode 100644 index 0000000..572ac95 --- /dev/null +++ b/text_recognizer/data/mappings/base_mapping.py @@ -0,0 +1,37 @@ +"""Mapping to and from word pieces.""" +from abc import ABC, abstractmethod +from typing import Dict, List + +from torch import Tensor + + +class AbstractMapping(ABC): + def __init__( + self, input_size: List[int], mapping: List[str], inverse_mapping: Dict[str, int] + ) -> None: + self.input_size = input_size + self.mapping = mapping + self.inverse_mapping = inverse_mapping + + def __len__(self) -> int: + return len(self.mapping) + + @property + def num_classes(self) -> int: + return self.__len__() + + @abstractmethod + def get_token(self, *args, **kwargs) -> str: + ... + + @abstractmethod + def get_index(self, *args, **kwargs) -> Tensor: + ... + + @abstractmethod + def get_text(self, *args, **kwargs) -> str: + ... + + @abstractmethod + def get_indices(self, *args, **kwargs) -> Tensor: + ... diff --git a/text_recognizer/data/mappings/emnist_essentials.json b/text_recognizer/data/mappings/emnist_essentials.json new file mode 100644 index 0000000..c412425 --- /dev/null +++ b/text_recognizer/data/mappings/emnist_essentials.json @@ -0,0 +1 @@ +{"characters": ["<b>", "<s>", "<e>", "<p>", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]}
\ No newline at end of file diff --git a/text_recognizer/data/mappings/emnist_mapping.py b/text_recognizer/data/mappings/emnist_mapping.py new file mode 100644 index 0000000..3eed3d8 --- /dev/null +++ b/text_recognizer/data/mappings/emnist_mapping.py @@ -0,0 +1,60 @@ +"""Emnist mapping.""" +from typing import List, Optional, Set, Union + +import torch +from torch import Tensor + +from text_recognizer.data.mappings.base_mapping import AbstractMapping +from text_recognizer.data.emnist import emnist_mapping + + +class EmnistMapping(AbstractMapping): + """Mapping for EMNIST labels.""" + + def __init__( + self, extra_symbols: Optional[Set[str]] = None, lower: bool = True + ) -> None: + self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None + self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( + self.extra_symbols + ) + if lower: + self._to_lower() + super().__init__(self.input_size, self.mapping, self.inverse_mapping) + + def _to_lower(self) -> None: + """Converts mapping to lowercase letters only.""" + + def _filter(x: int) -> int: + if 40 <= x: + return x - 26 + return x + + self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)} + self.mapping = [c for c in self.mapping if not c.isupper()] + + def get_token(self, index: Union[int, Tensor]) -> str: + """Returns token for index value.""" + if (index := int(index)) <= len(self.mapping): + return self.mapping[index] + raise KeyError(f"Index ({index}) not in mapping.") + + def get_index(self, token: str) -> Tensor: + """Returns index value of token.""" + if token in self.inverse_mapping: + return torch.LongTensor([self.inverse_mapping[token]]) + raise KeyError(f"Token ({token}) not found in inverse mapping.") + + def get_text(self, indices: Union[List[int], Tensor]) -> str: + """Returns the text from a list of indices.""" + if isinstance(indices, Tensor): + indices = indices.tolist() + return "".join([self.mapping[index] for index in indices]) + + def get_indices(self, text: str) -> Tensor: + """Returns tensor of indices for a string.""" + return Tensor([self.inverse_mapping[token] for token in text]) + + def __getitem__(self, x: Union[int, Tensor]) -> str: + """Returns text for a list of indices.""" + return self.get_token(x) diff --git a/text_recognizer/data/mappings/word_piece_mapping.py b/text_recognizer/data/mappings/word_piece_mapping.py new file mode 100644 index 0000000..6f1790e --- /dev/null +++ b/text_recognizer/data/mappings/word_piece_mapping.py @@ -0,0 +1,98 @@ +"""Word piece mapping.""" +from pathlib import Path +from typing import List, Optional, Set, Union + +from loguru import logger as log +import torch +from torch import Tensor + +from text_recognizer.data.mappings.emnist_mapping import EmnistMapping +from text_recognizer.data.utils.iam_preprocessor import Preprocessor + + +class WordPieceMapping(EmnistMapping): + """Word piece mapping.""" + + def __init__( + self, + data_dir: Optional[Path] = None, + num_features: int = 1000, + tokens: str = "iamdb_1kwp_tokens_1000.txt", + lexicon: str = "iamdb_1kwp_lex_1000.txt", + use_words: bool = False, + prepend_wordsep: bool = False, + special_tokens: Set[str] = {"<s>", "<e>", "<p>"}, + extra_symbols: Set[str] = {"\n"}, + ) -> None: + super().__init__(extra_symbols=extra_symbols) + self.data_dir = ( + ( + Path(__file__).resolve().parents[3] + / "data" + / "downloaded" + / "iam" + / "iamdb" + ) + if data_dir is None + else Path(data_dir) + ) + log.debug(f"Using data dir: {self.data_dir}") + if not self.data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}") + + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + special_tokens = set(special_tokens) + if self.extra_symbols is not None: + special_tokens = special_tokens | set(extra_symbols) + + self.wordpiece_processor = Preprocessor( + data_dir=self.data_dir, + num_features=num_features, + tokens_path=tokens_path, + lexicon_path=lexicon_path, + use_words=use_words, + prepend_wordsep=prepend_wordsep, + special_tokens=special_tokens, + ) + + def __len__(self) -> int: + """Return number of word pieces.""" + return len(self.wordpiece_processor.tokens) + + def get_token(self, index: Union[int, Tensor]) -> str: + """Returns token for index.""" + if (index := int(index)) <= self.wordpiece_processor.num_tokens: + return self.wordpiece_processor.tokens[index] + raise KeyError(f"Index ({index}) not in mapping.") + + def get_index(self, token: str) -> Tensor: + """Returns index of token.""" + if token in self.wordpiece_processor.tokens: + return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]]) + raise KeyError(f"Token ({token}) not found in inverse mapping.") + + def get_text(self, indices: Union[List[int], Tensor]) -> str: + """Returns text from indices.""" + if isinstance(indices, Tensor): + indices = indices.tolist() + return self.wordpiece_processor.to_text(indices) + + def get_indices(self, text: str) -> Tensor: + """Returns indices of text.""" + return self.wordpiece_processor.to_index(text) + + def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor: + """Returns word pieces indices from emnist indices.""" + text = "".join([self.mapping[i] for i in x]) + text = text.lower().replace(" ", "▁") + return torch.LongTensor(self.wordpiece_processor.to_index(text)) + + def __getitem__(self, x: Union[int, Tensor]) -> str: + """Returns token for word piece index.""" + return self.get_token(x) |