summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-02-24 22:00:29 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-02-24 22:00:29 +0100
commit905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 (patch)
tree91dab598a94911e6147b996237e786dd47f11f2f /src/text_recognizer/datasets
parent4a54d7e690897dd6e6c719fb908fd371a44c2952 (diff)
updates
Diffstat (limited to 'src/text_recognizer/datasets')
-rw-r--r--src/text_recognizer/datasets/iam_preprocessor.py2
-rw-r--r--src/text_recognizer/datasets/transforms.py119
2 files changed, 120 insertions, 1 deletions
diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py
index 5a5136c..a93eb00 100644
--- a/src/text_recognizer/datasets/iam_preprocessor.py
+++ b/src/text_recognizer/datasets/iam_preprocessor.py
@@ -59,7 +59,7 @@ class Preprocessor:
use_words: bool = False,
prepend_wordsep: bool = False,
) -> None:
- self.wordsep = "_"
+ self.wordsep = "▁"
self._use_word = use_words
self._prepend_wordsep = prepend_wordsep
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 60987e0..b6a48f5 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -1,6 +1,10 @@
"""Transforms for PyTorch datasets."""
+from abc import abstractmethod
+from pathlib import Path
import random
+from typing import Any, Optional, Union
+from loguru import logger
import numpy as np
from PIL import Image
import torch
@@ -18,6 +22,7 @@ from torchvision.transforms import (
ToTensor,
)
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
from text_recognizer.datasets.util import EmnistMapper
@@ -145,3 +150,117 @@ class ToLower:
"""Corrects index value in target tensor."""
device = target.device
return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
+
+
+class ToCharcters:
+ """Converts integers to characters."""
+
+ def __init__(
+ self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
+ ) -> None:
+ self.init_token = init_token
+ self.pad_token = pad_token
+ self.eos_token = eos_token
+ if self.init_token is not None:
+ self.emnist_mapper = EmnistMapper(
+ init_token=self.init_token,
+ pad_token=self.pad_token,
+ eos_token=self.eos_token,
+ lower=lower,
+ )
+ else:
+ self.emnist_mapper = EmnistMapper(
+ pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
+ )
+
+ def __call__(self, y: Tensor) -> str:
+ """Converts a Tensor to a str."""
+ return (
+ "".join([self.emnist_mapper(int(i)) for i in y])
+ .strip("_")
+ .replace(" ", "▁")
+ )
+
+
+class WordPieces:
+ """Abstract transform for word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ self.preprocessor = Preprocessor(
+ data_dir,
+ num_features,
+ tokens_path,
+ lexicon_path,
+ use_words,
+ prepend_wordsep,
+ )
+
+ @abstractmethod
+ def __call__(self, *args, **kwargs) -> Any:
+ """Transforms input."""
+ ...
+
+
+class ToWordPieces(WordPieces):
+ """Transforms str to word pieces."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, line: str) -> Tensor:
+ """Transforms str to word pieces."""
+ return self.preprocessor.to_index(line)
+
+
+class ToText(WordPieces):
+ """Takes word pieces and converts them to text."""
+
+ def __init__(
+ self,
+ num_features: int,
+ data_dir: Optional[Union[str, Path]] = None,
+ tokens: Optional[Union[str, Path]] = None,
+ lexicon: Optional[Union[str, Path]] = None,
+ use_words: bool = False,
+ prepend_wordsep: bool = False,
+ ) -> None:
+ super().__init__(
+ num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+ )
+
+ def __call__(self, x: Tensor) -> str:
+ """Converts tensor to text."""
+ return self.preprocessor.to_text(x.tolist())