Diffstat (limited to 'text_recognizer/data/mappings.py')
-rw-r--r--  text_recognizer/data/mappings.py  152
1 file changed, 152 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
new file mode 100644
index 0000000..cfa0ec7
--- /dev/null
+++ b/text_recognizer/data/mappings.py
@@ -0,0 +1,152 @@
+"""Mapping to and from word pieces."""
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import List, Optional, Union, Sequence
+
+from loguru import logger
+import torch
+from torch import Tensor
+
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+
+
+class AbstractMapping(ABC):
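+    """Abstract base class for mapping between tokens and indices."""
+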
+    @abstractmethod
+    def get_token(self, *args, **kwargs) -> str:
+        ...
+
+    @abstractmethod
+    def get_index(self, *args, **kwargs) -> Tensor:
+        ...
+
+    @abstractmethod
+    def get_text(self, *args, **kwargs) -> str:
+        ...
+
+    @abstractmethod
+    def get_indices(self, *args, **kwargs) -> Tensor:
+        ...
+
+
+class EmnistMapping(AbstractMapping):
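+    """Mapping between EMNIST characters and their integer indices."""
+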
+    def __init__(self, extra_symbols: Optional[Sequence[str]] = None) -> None:
+        self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
+            extra_symbols
+        )
+
+    def get_token(self, index: Union[int, Tensor]) -> str:
+        if (index := int(index)) in self.mapping:
+            return self.mapping[index]
+        raise KeyError(f"Index ({index}) not in mapping.")
+
+    def get_index(self, token: str) -> Tensor:
+        if token in self.inverse_mapping:
+            return torch.LongTensor([self.inverse_mapping[token]])
+        raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+    def get_text(self, indices: Union[List[int], Tensor]) -> str:
+        if isinstance(indices, Tensor):
+            indices = indices.tolist()
+        return "".join([self.mapping[index] for index in indices])
+
+    def get_indices(self, text: str) -> Tensor:
+        return torch.LongTensor([self.inverse_mapping[token] for token in text])
+
+
+class WordPieceMapping(EmnistMapping):
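+    """Mapping between word pieces and indices, backed by the IAM preprocessor."""
+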
+    def __init__(
+        self,
+        num_features: int,
+        tokens: str,
+        lexicon: str,
+        data_dir: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+        special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
+        extra_symbols: Optional[Sequence[str]] = None,
+    ) -> None:
+        super().__init__(extra_symbols)
+        self.wordpiece_processor = self._configure_wordpiece_processor(
+            num_features,
+            tokens,
+            lexicon,
+            data_dir,
+            use_words,
+            prepend_wordsep,
+            special_tokens,
+            extra_symbols,
+        )
+
+    def _configure_wordpiece_processor(
+        self,
+        num_features: int,
+        tokens: str,
+        lexicon: str,
+        data_dir: Optional[Union[str, Path]],
+        use_words: bool,
+        prepend_wordsep: bool,
+        special_tokens: Optional[Sequence[str]],
+        extra_symbols: Optional[Sequence[str]],
+    ) -> Preprocessor:
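+        """Builds the word-piece Preprocessor from the token and lexicon files."""
+        # Default to the raw IAM database location used by this repository layout.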
+        data_dir = (
+            (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb")
+            if data_dir is None
+            else Path(data_dir)
+        )
+
+        logger.debug(f"Using data dir: {data_dir}")
+        if not data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+
+        processed_path = (
+            Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
+        )
+
+        tokens_path = processed_path / tokens
+        lexicon_path = processed_path / lexicon
+
+        if extra_symbols is not None:
+            special_tokens = list(special_tokens) + list(extra_symbols)
+
+        return Preprocessor(
+            data_dir,
+            num_features,
+            tokens_path,
+            lexicon_path,
+            use_words,
+            prepend_wordsep,
+            special_tokens,
+        )
+
+    def get_token(self, index: Union[int, Tensor]) -> str:
+        if (index := int(index)) < self.wordpiece_processor.num_tokens:
+            return self.wordpiece_processor.tokens[index]
+        raise KeyError(f"Index ({index}) not in mapping.")
+
+    def get_index(self, token: str) -> Tensor:
+        if token in self.wordpiece_processor.tokens:
+            return torch.LongTensor([self.wordpiece_processor.tokens_to_index[token]])
+        raise KeyError(f"Token ({token}) not found in inverse mapping.")
+
+    def get_text(self, indices: Union[List[int], Tensor]) -> str:
+        if isinstance(indices, Tensor):
+            indices = indices.tolist()
+        return self.wordpiece_processor.to_text(indices)
+
+    def get_indices(self, text: str) -> Tensor:
+        return self.wordpiece_processor.to_index(text)
+
+    def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor:
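+        # Decode EMNIST indices to characters first; "▁" marks word boundaries.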
+        text = super().get_text(x)
+        text = text.lower().replace(" ", "▁")
+        return torch.LongTensor(self.wordpiece_processor.to_index(text))
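
For reference, a minimal usage sketch (the token/lexicon file names and the
num_features value below are hypothetical; the files are expected under
data/processed/iam_lines, as produced by the IAM preprocessing step):

    from text_recognizer.data.mappings import WordPieceMapping

    # Placeholder artifact names; substitute the files your preprocessing run emits.
    mapping = WordPieceMapping(
        num_features=1000,
        tokens="wordpiece_tokens_1000.txt",
        lexicon="wordpiece_lexicon_1000.txt",
    )
    indices = mapping.get_indices("the quick brown fox")
    print(mapping.get_text(indices))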