From 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Wed, 24 Feb 2021 22:00:29 +0100 Subject: updates --- src/text_recognizer/datasets/iam_preprocessor.py | 2 +- src/text_recognizer/datasets/transforms.py | 119 +++++++++++++++++++++++ 2 files changed, 120 insertions(+), 1 deletion(-) (limited to 'src/text_recognizer/datasets') diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/src/text_recognizer/datasets/iam_preprocessor.py index 5a5136c..a93eb00 100644 --- a/src/text_recognizer/datasets/iam_preprocessor.py +++ b/src/text_recognizer/datasets/iam_preprocessor.py @@ -59,7 +59,7 @@ class Preprocessor: use_words: bool = False, prepend_wordsep: bool = False, ) -> None: - self.wordsep = "_" + self.wordsep = "▁" self._use_word = use_words self._prepend_wordsep = prepend_wordsep diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index 60987e0..b6a48f5 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -1,6 +1,10 @@ """Transforms for PyTorch datasets.""" +from abc import abstractmethod +from pathlib import Path import random +from typing import Any, Optional, Union +from loguru import logger import numpy as np from PIL import Image import torch @@ -18,6 +22,7 @@ from torchvision.transforms import ( ToTensor, ) +from text_recognizer.datasets.iam_preprocessor import Preprocessor from text_recognizer.datasets.util import EmnistMapper @@ -145,3 +150,117 @@ class ToLower: """Corrects index value in target tensor.""" device = target.device return torch.stack([x - 26 if x > 35 else x for x in target]).to(device) + + +class ToCharcters: + """Converts integers to characters.""" + + def __init__( + self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True + ) -> None: + self.init_token = init_token + self.pad_token = pad_token + self.eos_token = eos_token + if self.init_token is not None: + self.emnist_mapper = EmnistMapper( + init_token=self.init_token, + pad_token=self.pad_token, + eos_token=self.eos_token, + lower=lower, + ) + else: + self.emnist_mapper = EmnistMapper( + pad_token=self.pad_token, eos_token=self.eos_token, lower=lower + ) + + def __call__(self, y: Tensor) -> str: + """Converts a Tensor to a str.""" + return ( + "".join([self.emnist_mapper(int(i)) for i in y]) + .strip("_") + .replace(" ", "▁") + ) + + +class WordPieces: + """Abstract transform for word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + processed_path = ( + Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines" + ) + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + self.preprocessor = Preprocessor( + data_dir, + num_features, + tokens_path, + lexicon_path, + use_words, + prepend_wordsep, + ) + + @abstractmethod + def __call__(self, *args, **kwargs) -> Any: + """Transforms input.""" + ... + + +class ToWordPieces(WordPieces): + """Transforms str to word pieces.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, line: str) -> Tensor: + """Transforms str to word pieces.""" + return self.preprocessor.to_index(line) + + +class ToText(WordPieces): + """Takes word pieces and converts them to text.""" + + def __init__( + self, + num_features: int, + data_dir: Optional[Union[str, Path]] = None, + tokens: Optional[Union[str, Path]] = None, + lexicon: Optional[Union[str, Path]] = None, + use_words: bool = False, + prepend_wordsep: bool = False, + ) -> None: + super().__init__( + num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep + ) + + def __call__(self, x: Tensor) -> str: + """Converts tensor to text.""" + return self.preprocessor.to_text(x.tolist()) -- cgit v1.2.3-70-g09d2