summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/mappings.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/mappings.py')
-rw-r--r-- text_recognizer/data/mappings.py 111
1 files changed, 50 insertions, 61 deletions
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 0d778b2..a934fd9 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,8 +1,9 @@
"""Mapping to and from word pieces."""
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import List, Optional, Union, Sequence
+from typing import Dict, List, Optional, Union, Set, Sequence
+import attr
from loguru import logger
import torch
from torch import Tensor
@@ -29,10 +30,17 @@ class AbstractMapping(ABC):
...
+@attr.s
class EmnistMapping(AbstractMapping):
- def __init__(self, extra_symbols: Optional[Sequence[str]]) -> None:
+ extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
+ mapping: Sequence[str] = attr.ib(init=False)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
+ input_size: List[int] = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
- extra_symbols
+ self.extra_symbols
)
def get_token(self, index: Union[int, Tensor]) -> str:
@@ -54,42 +62,21 @@ class EmnistMapping(AbstractMapping):
return Tensor([self.inverse_mapping[token] for token in text])
+@attr.s(auto_attribs=True)
class WordPieceMapping(EmnistMapping):
- def __init__(
- self,
- num_features: int = 1000,
- tokens: str = "iamdb_1kwp_tokens_1000.txt",
- lexicon: str = "iamdb_1kwp_lex_1000.txt",
- data_dir: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
- extra_symbols: Optional[Sequence[str]] = ("\n",),
- ) -> None:
- super().__init__(extra_symbols)
- self.wordpiece_processor = self._configure_wordpiece_processor(
- num_features,
- tokens,
- lexicon,
- data_dir,
- use_words,
- prepend_wordsep,
- special_tokens,
- extra_symbols,
- )
-
- @staticmethod
- def _configure_wordpiece_processor(
- num_features: int,
- tokens: str,
- lexicon: str,
- data_dir: Optional[Union[str, Path]],
- use_words: bool,
- prepend_wordsep: bool,
- special_tokens: Optional[Sequence[str]],
- extra_symbols: Optional[Sequence[str]],
- ) -> Preprocessor:
- data_dir = (
+ data_dir: Optional[Path] = attr.ib(default=None)
+ num_features: int = attr.ib(default=1000)
+ tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
+ lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
+ use_words: bool = attr.ib(default=False)
+ prepend_wordsep: bool = attr.ib(default=False)
+ special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
+ extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
+ wordpiece_processor: Preprocessor = attr.ib(init=False)
+
+ def __attrs_post_init__(self) -> None:
+ super().__attrs_post_init__()
+ self.data_dir = (
(
Path(__file__).resolve().parents[2]
/ "data"
@@ -97,32 +84,32 @@ class WordPieceMapping(EmnistMapping):
/ "iam"
/ "iamdb"
)
- if data_dir is None
- else Path(data_dir)
+ if self.data_dir is None
+ else Path(self.data_dir)
)
-
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ logger.debug(f"Using data dir: {self.data_dir}")
+ if not self.data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
processed_path = (
Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
)
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- if extra_symbols is not None:
- special_tokens += extra_symbols
-
- return Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
- special_tokens,
+ tokens_path = processed_path / self.tokens
+ lexicon_path = processed_path / self.lexicon
+
+ special_tokens = self.special_tokens
+ if self.extra_symbols is not None:
+ special_tokens = special_tokens | self.extra_symbols
+
+ self.wordpiece_processor = Preprocessor(
+ data_dir=self.data_dir,
+ num_features=self.num_features,
+ tokens_path=tokens_path,
+ lexicon_path=lexicon_path,
+ use_words=self.use_words,
+ prepend_wordsep=self.prepend_wordsep,
+ special_tokens=special_tokens,
)
def __len__(self) -> int:
@@ -151,7 +138,9 @@ class WordPieceMapping(EmnistMapping):
text = text.lower().replace(" ", "▁")
return torch.LongTensor(self.wordpiece_processor.to_index(text))
- def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+ def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+ if isinstance(x, int):
+ x = [x]
if isinstance(x, str):
- return self.get_index(x)
- return self.get_token(x)
+ return self.get_indices(x)
+ return self.get_text(x)