diff options
Diffstat (limited to 'src/text_recognizer/datasets/transforms.py')
-rw-r--r-- | src/text_recognizer/datasets/transforms.py | 45 |
1 files changed, 44 insertions, 1 deletions
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py index 8956b01..60987e0 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/src/text_recognizer/datasets/transforms.py @@ -1,14 +1,57 @@ """Transforms for PyTorch datasets.""" +import random + import numpy as np from PIL import Image import torch from torch import Tensor import torch.nn.functional as F -from torchvision.transforms import Compose, RandomAffine, RandomHorizontalFlip, ToTensor +from torchvision import transforms +from torchvision.transforms import ( + ColorJitter, + Compose, + Normalize, + RandomAffine, + RandomHorizontalFlip, + RandomRotation, + ToPILImage, + ToTensor, +) from text_recognizer.datasets.util import EmnistMapper +class RandomResizeCrop: + """Image transform with random resize and crop applied. + + Stolen from + + https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py + + """ + + def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None: + self.jitter = jitter + self.ratio = ratio + + def __call__(self, img: np.ndarray) -> np.ndarray: + """Applies random crop and rotation to an image.""" + w, h = img.size + + # pad with white: + img = transforms.functional.pad(img, self.jitter, fill=255) + + # crop at random (x, y): + x = self.jitter + random.randint(-self.jitter, self.jitter) + y = self.jitter + random.randint(-self.jitter, self.jitter) + + # randomize aspect ratio: + size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio) + size = (h, int(size_w)) + img = transforms.functional.resized_crop(img, y, x, h, w, size) + return img + + class Transpose: """Transposes the EMNIST image to the correct orientation.""" |