summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/transforms.py')
-rw-r--r--text_recognizer/data/transforms.py111
1 file changed, 16 insertions, 95 deletions
diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py
index 297c953..f53df64 100644
--- a/text_recognizer/data/transforms.py
+++ b/text_recognizer/data/transforms.py
@@ -1,115 +1,36 @@
"""Transforms for PyTorch datasets."""
-from abc import abstractmethod
from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Optional, Union, Sequence
-from loguru import logger
-import torch
from torch import Tensor
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.datasets.mappings import WordPieceMapping
-class ToLower:
- """Converts target to lower case."""
-
- def __call__(self, target: Tensor) -> Tensor:
- """Corrects index value in target tensor."""
- device = target.device
- return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
-
-
-class ToCharcters:
- """Converts integers to characters."""
-
- def __init__(self, extra_symbols: Optional[List[str]] = None) -> None:
- self.mapping, _, _ = emnist_mapping(extra_symbols)
-
- def __call__(self, y: Tensor) -> str:
- """Converts a Tensor to a str."""
- return "".join([self.mapping[int(i)] for i in y]).replace(" ", "▁")
-
-
-class WordPieces:
- """Abstract transform for word pieces."""
+class WordPiece:
+ """Converts EMNIST indices to Word Piece indices."""
def __init__(
self,
num_features: int,
+ tokens: str,
+ lexicon: str,
data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
+ special_tokens: Sequence[str] = ("<s>", "<e>", "<p>"),
+ extra_symbols: Optional[Sequence[str]] = None,
) -> None:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- self.preprocessor = Preprocessor(
- data_dir,
+ self.mapping = WordPieceMapping(
num_features,
- tokens_path,
- lexicon_path,
+ tokens,
+ lexicon,
+ data_dir,
use_words,
prepend_wordsep,
+ special_tokens,
+ extra_symbols,
)
- @abstractmethod
- def __call__(self, *args, **kwargs) -> Any:
- """Transforms input."""
- ...
-
-
-class ToWordPieces(WordPieces):
- """Transforms str to word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, line: str) -> Tensor:
- """Transforms str to word pieces."""
- return self.preprocessor.to_index(line)
-
-
-class ToText(WordPieces):
- """Takes word pieces and converts them to text."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, x: Tensor) -> str:
- """Converts tensor to text."""
- return self.preprocessor.to_text(x.tolist())
+ def __call__(self, x: Tensor) -> Tensor:
+ return self.mapping.emnist_to_wordpiece_indices(x)