From 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 Mon Sep 17 00:00:00 2001
From: aktersnurra <gustaf.rydholm@gmail.com>
Date: Wed, 24 Feb 2021 22:00:29 +0100
Subject: updates

---
 src/text_recognizer/datasets/iam_preprocessor.py |   2 +-
 src/text_recognizer/datasets/transforms.py       | 119 +++++++++++++++++++++++
 2 files changed, 120 insertions(+), 1 deletion(-)

(limited to 'src/text_recognizer/datasets')

diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py
index 5a5136c..a93eb00 100644
--- a/src/text_recognizer/datasets/iam_preprocessor.py
+++ b/src/text_recognizer/datasets/iam_preprocessor.py
@@ -59,7 +59,7 @@ class Preprocessor:
         use_words: bool = False,
         prepend_wordsep: bool = False,
     ) -> None:
-        self.wordsep = "_"
+        self.wordsep = "▁"
         self._use_word = use_words
         self._prepend_wordsep = prepend_wordsep
 
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 60987e0..b6a48f5 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -1,6 +1,10 @@
 """Transforms for PyTorch datasets."""
+from abc import abstractmethod
+from pathlib import Path
 import random
+from typing import Any, Optional, Union
 
+from loguru import logger
 import numpy as np
 from PIL import Image
 import torch
@@ -18,6 +22,7 @@ from torchvision.transforms import (
     ToTensor,
 )
 
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
 from text_recognizer.datasets.util import EmnistMapper
 
 
@@ -145,3 +150,117 @@ class ToLower:
         """Corrects index value in target tensor."""
         device = target.device
         return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
+
+
+class ToCharcters:
+    """Converts integers to characters."""
+
+    def __init__(
+        self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
+    ) -> None:
+        self.init_token = init_token
+        self.pad_token = pad_token
+        self.eos_token = eos_token
+        if self.init_token is not None:
+            self.emnist_mapper = EmnistMapper(
+                init_token=self.init_token,
+                pad_token=self.pad_token,
+                eos_token=self.eos_token,
+                lower=lower,
+            )
+        else:
+            self.emnist_mapper = EmnistMapper(
+                pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
+            )
+
+    def __call__(self, y: Tensor) -> str:
+        """Converts a Tensor to a str."""
+        return (
+            "".join([self.emnist_mapper(int(i)) for i in y])
+            .strip("_")
+            .replace(" ", "▁")
+        )
+
+
+class WordPieces:
+    """Abstract transform for word pieces."""
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        if data_dir is None:
+            data_dir = (
+                Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
+            )
+            logger.debug(f"Using data dir: {data_dir}")
+            if not data_dir.exists():
+                raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+        else:
+            data_dir = Path(data_dir)
+        processed_path = (
+            Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
+        )
+        tokens_path = processed_path / tokens
+        lexicon_path = processed_path / lexicon
+
+        self.preprocessor = Preprocessor(
+            data_dir,
+            num_features,
+            tokens_path,
+            lexicon_path,
+            use_words,
+            prepend_wordsep,
+        )
+
+    @abstractmethod
+    def __call__(self, *args, **kwargs) -> Any:
+        """Transforms input."""
+        ...
+
+
+class ToWordPieces(WordPieces):
+    """Transforms str to word pieces."""
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        super().__init__(
+            num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+        )
+
+    def __call__(self, line: str) -> Tensor:
+        """Transforms str to word pieces."""
+        return self.preprocessor.to_index(line)
+
+
+class ToText(WordPieces):
+    """Takes word pieces and converts them to text."""
+
+    def __init__(
+        self,
+        num_features: int,
+        data_dir: Optional[Union[str, Path]] = None,
+        tokens: Optional[Union[str, Path]] = None,
+        lexicon: Optional[Union[str, Path]] = None,
+        use_words: bool = False,
+        prepend_wordsep: bool = False,
+    ) -> None:
+        super().__init__(
+            num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
+        )
+
+    def __call__(self, x: Tensor) -> str:
+        """Converts tensor to text."""
+        return self.preprocessor.to_text(x.tolist())
-- 
cgit v1.2.3-70-g09d2