diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 1 | ||||
-rw-r--r-- | text_recognizer/data/mapping.py | 8 | ||||
-rw-r--r-- | text_recognizer/data/mappings.py | 143 | ||||
-rw-r--r-- | text_recognizer/data/transforms.py | 111 | ||||
-rw-r--r-- | text_recognizer/models/base.py | 6 | ||||
-rw-r--r-- | text_recognizer/networks/image_transformer.py | 4 |
6 files changed, 166 insertions, 107 deletions
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index d85787e..60f8a9f 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -119,7 +119,6 @@ class Preprocessor: continue self.text.append(example["text"].lower()) - def _to_index(self, line: str) -> torch.LongTensor: if line in self.special_tokens: return torch.LongTensor([self.tokens_to_index[line]]) diff --git a/text_recognizer/data/mapping.py b/text_recognizer/data/mapping.py deleted file mode 100644 index f0edf3f..0000000 --- a/text_recognizer/data/mapping.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Mapping to and from word pieces.""" -from pathlib import Path - - -class WordPieces: - - def __init__(self) -> None: - pass diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py new file mode 100644 index 0000000..cfa0ec7 --- /dev/null +++ b/text_recognizer/data/mappings.py @@ -0,0 +1,143 @@ +"""Mapping to and from word pieces.""" +from abc import ABC, abstractmethod +from pathlib import Path +from typing import List, Optional, Union, Sequence + +from loguru import logger +import torch +from torch import Tensor + +from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.datasets.iam_preprocessor import Preprocessor + + +class AbstractMapping(ABC): + @abstractmethod + def get_token(self, *args, **kwargs) -> str: + ... + + @abstractmethod + def get_index(self, *args, **kwargs) -> Tensor: + ... + + @abstractmethod + def get_text(self, *args, **kwargs) -> str: + ... + + @abstractmethod + def get_indices(self, *args, **kwargs) -> Tensor: + ... + + +class EmnistMapping(AbstractMapping): + def __init__(self, extra_symbols: Optional[Sequence[str]]) -> None: + self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( + extra_symbols + ) + + def get_token(self, index: Union[int, Tensor]) -> str: + if (index := int(index)) in self.mapping: + return self.mapping[index] + raise KeyError(f"Index ({index}) not in mapping.") + + def get_index(self, token: str) -> Tensor: + if token in self.inverse_mapping: + return Tensor(self.inverse_mapping[token]) + raise KeyError(f"Token ({token}) not found in inverse mapping.") + + def get_text(self, indices: Union[List[int], Tensor]) -> str: + if isinstance(indices, Tensor): + indices = indices.tolist() + return "".join([self.mapping[index] for index in indices]) + + def get_indices(self, text: str) -> Tensor: + return Tensor([self.inverse_mapping[token] for token in text]) + + +class WordPieceMapping(EmnistMapping): + def __init__( + self, + num_features: int, + tokens: str, + lexicon: str, + data_dir: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), + extra_symbols: Optional[Sequence[str]] = None, + ) -> None: + super().__init__(extra_symbols) + self.wordpiece_processor = self._configure_wordpiece_processor( + num_features, + tokens, + lexicon, + data_dir, + use_words, + prepend_wordsep, + special_tokens, + extra_symbols, + ) + + def _configure_wordpiece_processor( + self, + num_features: int, + tokens: str, + lexicon: str, + data_dir: Optional[Union[str, Path]], + use_words: bool, + prepend_wordsep: bool, + special_tokens: Optional[Sequence[str]], + extra_symbols: Optional[Sequence[str]], + ) -> Preprocessor: + data_dir = ( + (Path(__file__).resolve().parents[2] / "data" / "raw" / "iam" / "iamdb") + if data_dir is None + else Path(data_dir) + ) + + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + + processed_path = ( + Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" + ) + + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + if extra_symbols is not None: + special_tokens += extra_symbols + + return Preprocessor( + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, + special_tokens, + ) + + def get_token(self, index: Union[int, Tensor]) -> str: + if (index := int(index)) <= self.wordpiece_processor.num_tokens: + return self.wordpiece_processor.tokens[index] + raise KeyError(f"Index ({index}) not in mapping.") + + def get_index(self, token: str) -> Tensor: + if token in self.wordpiece_processor.tokens: + return torch.LongTensor(self.wordpiece_processor.tokens_to_index[token]) + raise KeyError(f"Token ({token}) not found in inverse mapping.") + + def get_text(self, indices: Union[List[int], Tensor]) -> str: + if isinstance(indices, Tensor): + indices = indices.tolist() + return self.wordpiece_processor.to_text(indices) + + def get_indices(self, text: str) -> Tensor: + return self.wordpiece_processor.to_index(text) + + def emnist_to_wordpiece_indices(self, x: Tensor) -> Tensor: + text = self.mapping.get_text(x) + text = text.lower().replace(" ", "▁") + return torch.LongTensor(self.wordpiece_processor.to_index(text)) diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index 297c953..f53df64 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -1,115 +1,36 @@ """Transforms for PyTorch datasets.""" -from abc import abstractmethod from pathlib import Path -from typing import Any, Optional, Union +from typing import Optional, Union, Sequence -from loguru import logger -import torch from torch import Tensor -from text_recognizer.datasets.iam_preprocessor import Preprocessor -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.datasets.mappings import WordPieceMapping -class ToLower: - """Converts target to lower case.""" - - def __call__(self, target: Tensor) -> Tensor: - """Corrects index value in target tensor.""" - device = target.device - return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) - - -class ToCharcters: - """Converts integers to characters.""" - - def __init__(self, extra_symbols: Optional[List[str]] = None) -> None: - self.mapping, _, _ = emnist_mapping(extra_symbols) - - def __call__(self, y: Tensor) -> str: - """Converts a Tensor to a str.""" - return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁") - - -class WordPieces: - """Abstract transform for word pieces.""" +class WordPiece: + """Converts EMNIST indices to Word Piece indices.""" def __init__( self, num_features: int, + tokens: str, + lexicon: str, data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, + special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"), + extra_symbols: Optional[Sequence[str]] = None, ) -> None: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - self.preprocessor = Preprocessor( - data_dir, + self.mapping = WordPieceMapping( num_features, - tokens_path, - lexicon_path, + tokens, + lexicon, + data_dir, use_words, prepend_wordsep, + special_tokens, + extra_symbols, ) - @abstractmethod - def __call__(self, *args, **kwargs) -> Any: - """Transforms input.""" - ... - - -class ToWordPieces(WordPieces): - """Transforms str to word pieces.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, line: str) -> Tensor: - """Transforms str to word pieces.""" - return self.preprocessor.to_index(line) - - -class ToText(WordPieces): - """Takes word pieces and converts them to text.""" - - def __init__( - self, - num_features: int, - data_dir: Optional[Union[str, Path]] = None, - tokens: Optional[Union[str, Path]] = None, - lexicon: Optional[Union[str, Path]] = None, - use_words: bool = False, - prepend_wordsep: bool = False, - ) -> None: - super().__init__( - num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep - ) - - def __call__(self, x: Tensor) -> str: - """Converts tensor to text.""" - return self.preprocessor.to_text(x.tolist()) + def __call__(self, x: Tensor) -> Tensor: + return self.mapping.emnist_to_wordpiece_indices(x) diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index c6d5d73..aeda039 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -49,7 +49,9 @@ class LitBaseModel(pl.LightningModule): optimizer_class = getattr(torch.optim, self._optimizer.type) return optimizer_class(params=self.parameters(), **args) - def _configure_lr_scheduler(self, optimizer: Type[torch.optim.Optimizer]) -> Dict[str, Any]: + def _configure_lr_scheduler( + self, optimizer: Type[torch.optim.Optimizer] + ) -> Dict[str, Any]: """Configures the lr scheduler.""" scheduler = {"monitor": self.monitor} args = {} or self._lr_scheduler.args @@ -59,7 +61,7 @@ class LitBaseModel(pl.LightningModule): scheduler["scheduler"] = getattr( torch.optim.lr_scheduler, self._lr_scheduler.type - )(optimizer, **args) + )(optimizer, **args) return scheduler diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py index daededa..a6aaca4 100644 --- a/text_recognizer/networks/image_transformer.py +++ b/text_recognizer/networks/image_transformer.py @@ -44,7 +44,9 @@ class ImageTransformer(nn.Module): dropout_rate: float = 0.1, transformer_activation: str = "glu", ) -> None: - self.vocab_size = NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size + self.vocab_size = ( + NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size + ) self.hidden_dim = hidden_dim self.max_output_length = output_shape[0] |