summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/transforms.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/datasets/transforms.py')
-rw-r--r--src/text_recognizer/datasets/transforms.py119
1 files changed, 119 insertions, 0 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 60987e0..b6a48f5 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -1,6 +1,10 @@
"""Transforms for PyTorch datasets."""
+from abc import abstractmethod
+from pathlib import Path
import random
+from typing import Any, Optional, Union
+from loguru import logger
import numpy as np
from PIL import Image
import torch
@@ -18,6 +22,7 @@ from torchvision.transforms import (
ToTensor,
)
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
from text_recognizer.datasets.util import EmnistMapper
@@ -145,3 +150,117 @@ class ToLower:
"""Corrects index value in target tensor."""
device = target.device
return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
+
+
+class ToCharcters:
+ """Converts integers to characters."""
+
+ def __init__(
+ self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
+ ) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ lower=lower,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
+ )
+
+ def __call__(self, y: Tensor) -> str:
+ """Converts a Tensor to a str."""
+ return (
+ "".join([self.emnist_mapper(int(i)) for i in y])
+ .strip("_")
+ .replace(" ", "▁")
+ )
+
+
+class WordPieces:
+ """Abstract transform for word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ self.preprocessor = Preprocessor(
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
+ )
+
+ @abstractmethod
+ def __call__(self, *args, **kwargs) -> Any:
+ """Transforms input."""
+ ...
+
+
+class ToWordPieces(WordPieces):
+ """Transforms str to word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, line: str) -> Tensor:
+ """Transforms str to word pieces."""
+ return self.preprocessor.to_index(line)
+
+
+class ToText(WordPieces):
+ """Takes word pieces and converts them to text."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, x: Tensor) -> str:
+ """Converts tensor to text."""
+ return self.preprocessor.to_text(x.tolist())