From 7268035fb9e57342612a8cc50a1fe04e8841ca2f Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm
Date: Fri, 30 Jul 2021 23:15:03 +0200
Subject: attr bug fix, properly loading network

---
 text_recognizer/data/mappings.py | 111 ++++++++++++++++++---------------------
 1 file changed, 50 insertions(+), 61 deletions(-)

diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index 0d778b2..a934fd9 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,8 +1,9 @@
 """Mapping to and from word pieces."""
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import List, Optional, Union, Sequence
+from typing import Dict, List, Optional, Union, Set, Sequence
 
+import attr
 from loguru import logger
 import torch
 from torch import Tensor
@@ -29,10 +30,17 @@ class AbstractMapping(ABC):
         ...
 
 
+@attr.s
 class EmnistMapping(AbstractMapping):
-    def __init__(self, extra_symbols: Optional[Sequence[str]]) -> None:
+    extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
+    mapping: Sequence[str] = attr.ib(init=False)
+    inverse_mapping: Dict[str, int] = attr.ib(init=False)
+    input_size: List[int] = attr.ib(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        """Post init configuration."""
         self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
-            extra_symbols
+            self.extra_symbols
         )
 
     def get_token(self, index: Union[int, Tensor]) -> str:
@@ -54,42 +62,21 @@ class EmnistMapping(AbstractMapping):
         return Tensor([self.inverse_mapping[token] for token in text])
 
 
+@attr.s(auto_attribs=True)
 class WordPieceMapping(EmnistMapping):
-    def __init__(
-        self,
-        num_features: int = 1000,
-        tokens: str = "iamdb_1kwp_tokens_1000.txt",
-        lexicon: str = "iamdb_1kwp_lex_1000.txt",
-        data_dir: Optional[Union[str, Path]] = None,
-        use_words: bool = False,
-        prepend_wordsep: bool = False,
-        special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
-        extra_symbols: Optional[Sequence[str]] = ("\n",),
-    ) -> None:
-        super().__init__(extra_symbols)
-        self.wordpiece_processor = self._configure_wordpiece_processor(
-            num_features,
-            tokens,
-            lexicon,
-            data_dir,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
-            extra_symbols,
-        )
-
-    @staticmethod
-    def _configure_wordpiece_processor(
-        num_features: int,
-        tokens: str,
-        lexicon: str,
-        data_dir: Optional[Union[str, Path]],
-        use_words: bool,
-        prepend_wordsep: bool,
-        special_tokens: Optional[Sequence[str]],
-        extra_symbols: Optional[Sequence[str]],
-    ) -> Preprocessor:
-        data_dir = (
+    data_dir: Optional[Path] = attr.ib(default=None)
+    num_features: int = attr.ib(default=1000)
+    tokens: str = attr.ib(default="iamdb_1kwp_tokens_1000.txt")
+    lexicon: str = attr.ib(default="iamdb_1kwp_lex_1000.txt")
+    use_words: bool = attr.ib(default=False)
+    prepend_wordsep: bool = attr.ib(default=False)
+    special_tokens: Set[str] = attr.ib(default={"<s>", "<e>", "<p>"}, converter=set)
+    extra_symbols: Set[str] = attr.ib(default={"\n",}, converter=set)
+    wordpiece_processor: Preprocessor = attr.ib(init=False)
+
+    def __attrs_post_init__(self) -> None:
+        super().__attrs_post_init__()
+        self.data_dir = (
             (
                 Path(__file__).resolve().parents[2]
                 / "data"
@@ -97,32 +84,32 @@ class WordPieceMapping(EmnistMapping):
                 / "iam"
                 / "iamdb"
             )
-            if data_dir is None
-            else Path(data_dir)
+            if self.data_dir is None
+            else Path(self.data_dir)
         )
-
-        logger.debug(f"Using data dir: {data_dir}")
-        if not data_dir.exists():
-            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+        logger.debug(f"Using data dir: {self.data_dir}")
+        if not self.data_dir.exists():
+            raise RuntimeError(f"Could not locate iamdb directory at {self.data_dir}")
 
         processed_path = (
             Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
         )
 
-        tokens_path = processed_path / tokens
-        lexicon_path = processed_path / lexicon
-
-        if extra_symbols is not None:
-            special_tokens += extra_symbols
-
-        return Preprocessor(
-            data_dir,
-            num_features,
-            tokens_path,
-            lexicon_path,
-            use_words,
-            prepend_wordsep,
-            special_tokens,
+        tokens_path = processed_path / self.tokens
+        lexicon_path = processed_path / self.lexicon
+
+        special_tokens = self.special_tokens
+        if self.extra_symbols is not None:
+            special_tokens = special_tokens | self.extra_symbols
+
+        self.wordpiece_processor = Preprocessor(
+            data_dir=self.data_dir,
+            num_features=self.num_features,
+            tokens_path=tokens_path,
+            lexicon_path=lexicon_path,
+            use_words=self.use_words,
+            prepend_wordsep=self.prepend_wordsep,
+            special_tokens=special_tokens,
         )
 
     def __len__(self) -> int:
@@ -151,7 +138,9 @@ class WordPieceMapping(EmnistMapping):
         text = text.lower().replace(" ", "▁")
         return torch.LongTensor(self.wordpiece_processor.to_index(text))
 
-    def __getitem__(self, x: Union[str, int, Tensor]) -> Union[str, Tensor]:
+    def __getitem__(self, x: Union[str, int, List[int], Tensor]) -> Union[str, Tensor]:
+        if isinstance(x, int):
+            x = [x]
         if isinstance(x, str):
-            return self.get_index(x)
-        return self.get_token(x)
+            return self.get_indices(x)
+        return self.get_text(x)
--
cgit v1.2.3-70-g09d2
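The core pattern in this commit is swapping hand-written __init__ methods for
attrs classes: constructor arguments become attr.ib() fields, derived state is
declared with init=False, and __attrs_post_init__ fills it in after the
generated __init__ runs. Below is a minimal standalone sketch of that pattern;
the Greeting class is hypothetical, not part of text_recognizer:

    from typing import Set

    import attr


    @attr.s(auto_attribs=True)
    class Greeting:
        # Caller-supplied field; the converter normalizes any iterable to a
        # set, mirroring the converter=set fields in the patch above.
        names: Set[str] = attr.ib(default={"world"}, converter=set)
        # Derived field: init=False keeps it out of the generated __init__,
        # so it must be assigned in __attrs_post_init__.
        message: str = attr.ib(init=False)

        def __attrs_post_init__(self) -> None:
            # Runs right after the attrs-generated __init__, analogous to the
            # hook that builds the wordpiece processor in the patch.
            self.message = "Hello, " + ", ".join(sorted(self.names))


    print(Greeting(names=["ada", "alan"]).message)  # Hello, ada, alan

One caveat worth noting: attrs applies converters to default values as well,
so extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
will call set(None) and raise a TypeError when EmnistMapping is constructed
without arguments; attr.converters.optional(set) is the usual way to keep a
None default while still converting supplied values.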