summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/transforms.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/datasets/transforms.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/datasets/transforms.py')
-rw-r--r--src/text_recognizer/datasets/transforms.py266
1 files changed, 0 insertions, 266 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
deleted file mode 100644
index b6a48f5..0000000
--- a/src/text_recognizer/datasets/transforms.py
+++ /dev/null
@@ -1,266 +0,0 @@
-"""Transforms for PyTorch datasets."""
-from abc import abstractmethod
-from pathlib import Path
-import random
-from typing import Any, Optional, Union
-
-from loguru import logger
-import numpy as np
-from PIL import Image
-import torch
-from torch import Tensor
-import torch.nn.functional as F
-from torchvision import transforms
-from torchvision.transforms import (
- ColorJitter,
- Compose,
- Normalize,
- RandomAffine,
- RandomHorizontalFlip,
- RandomRotation,
- ToPILImage,
- ToTensor,
-)
-
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-from text_recognizer.datasets.util import EmnistMapper
-
-
-class RandomResizeCrop:
- """Image transform with random resize and crop applied.
-
- Stolen from
-
- https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
- """
-
- def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
- self.jitter = jitter
- self.ratio = ratio
-
- def __call__(self, img: np.ndarray) -> np.ndarray:
- """Applies random crop and rotation to an image."""
- w, h = img.size
-
- # pad with white:
- img = transforms.functional.pad(img, self.jitter, fill=255)
-
- # crop at random (x, y):
- x = self.jitter + random.randint(-self.jitter, self.jitter)
- y = self.jitter + random.randint(-self.jitter, self.jitter)
-
- # randomize aspect ratio:
- size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
- size = (h, int(size_w))
- img = transforms.functional.resized_crop(img, y, x, h, w, size)
- return img
-
-
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
-
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
-
-
-class Resize:
- """Resizes a tensor to a specified width."""
-
- def __init__(self, width: int = 952) -> None:
- # The default is 952 because of the IAM dataset.
- self.width = width
-
- def __call__(self, image: Tensor) -> Tensor:
- """Resize tensor in the last dimension."""
- return F.interpolate(image, size=self.width, mode="nearest")
-
-
-class AddTokens:
- """Adds start of sequence and end of sequence tokens to target tensor."""
-
- def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token,
- )
- self.pad_value = self.emnist_mapper(self.pad_token)
- self.eos_value = self.emnist_mapper(self.eos_token)
-
- def __call__(self, target: Tensor) -> Tensor:
- """Adds a sos token to the begining and a eos token to the end of a target sequence."""
- dtype, device = target.dtype, target.device
-
- # Find the where padding starts.
- pad_index = torch.nonzero(target == self.pad_value, as_tuple=False)[0].item()
-
- target[pad_index] = self.eos_value
-
- if self.init_token is not None:
- self.sos_value = self.emnist_mapper(self.init_token)
- sos = torch.tensor([self.sos_value], dtype=dtype, device=device)
- target = torch.cat([sos, target], dim=0)
-
- return target
-
-
-class ApplyContrast:
- """Sets everything below a threshold to zero, i.e. increase contrast."""
-
- def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
- self.low = low
- self.high = high
-
- def __call__(self, x: Tensor) -> Tensor:
- """Apply mask binary mask to input tensor."""
- mask = x > np.random.RandomState().uniform(low=self.low, high=self.high)
- return x * mask
-
-
-class Unsqueeze:
- """Add a dimension to the tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Adds dim."""
- return x.unsqueeze(0)
-
-
-class Squeeze:
- """Removes the first dimension of a tensor."""
-
- def __call__(self, x: Tensor) -> Tensor:
- """Removes first dim."""
- return x.squeeze(0)
-
-
-class ToLower:
- """Converts target to lower case."""
-
- def __call__(self, target: Tensor) -> Tensor:
- """Corrects index value in target tensor."""
- device = target.device
- return torch.stack([x - 26 if x > 35 else x for x in target]).to(device)
-
-
-class ToCharcters:
- """Converts integers to characters."""
-
- def __init__(
- self, pad_token: str, eos_token: str, init_token: str = None, lower: bool = True
- ) -> None:
- self.init_token = init_token
- self.pad_token = pad_token
- self.eos_token = eos_token
- if self.init_token is not None:
- self.emnist_mapper = EmnistMapper(
- init_token=self.init_token,
- pad_token=self.pad_token,
- eos_token=self.eos_token,
- lower=lower,
- )
- else:
- self.emnist_mapper = EmnistMapper(
- pad_token=self.pad_token, eos_token=self.eos_token, lower=lower
- )
-
- def __call__(self, y: Tensor) -> str:
- """Converts a Tensor to a str."""
- return (
- "".join([self.emnist_mapper(int(i)) for i in y])
- .strip("_")
- .replace(" ", "▁")
- )
-
-
-class WordPieces:
- """Abstract transform for word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[3] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[3] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- self.preprocessor = Preprocessor(
- data_dir,
- num_features,
- tokens_path,
- lexicon_path,
- use_words,
- prepend_wordsep,
- )
-
- @abstractmethod
- def __call__(self, *args, **kwargs) -> Any:
- """Transforms input."""
- ...
-
-
-class ToWordPieces(WordPieces):
- """Transforms str to word pieces."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, line: str) -> Tensor:
- """Transforms str to word pieces."""
- return self.preprocessor.to_index(line)
-
-
-class ToText(WordPieces):
- """Takes word pieces and converts them to text."""
-
- def __init__(
- self,
- num_features: int,
- data_dir: Optional[Union[str, Path]] = None,
- tokens: Optional[Union[str, Path]] = None,
- lexicon: Optional[Union[str, Path]] = None,
- use_words: bool = False,
- prepend_wordsep: bool = False,
- ) -> None:
- super().__init__(
- num_features, data_dir, tokens, lexicon, use_words, prepend_wordsep
- )
-
- def __call__(self, x: Tensor) -> str:
- """Converts tensor to text."""
- return self.preprocessor.to_text(x.tolist())