1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
|
"""Transforms for PyTorch datasets."""
import random
import numpy as np
from PIL import Image
import torch
from torch import Tensor
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import (
ColorJitter,
Compose,
Normalize,
RandomAffine,
RandomHorizontalFlip,
RandomRotation,
ToPILImage,
ToTensor,
)
from text_recognizer.datasets.util import EmnistMapper
class RandomResizeCrop:
    """Image transform that applies a random crop and a random width rescale.

    The image is padded with white, a window of the original size is cropped
    at a randomly jittered offset, and the crop is resized to the original
    height and a randomly scaled width.

    Adapted from
    https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py

    Args:
        jitter: Padding added on every side, in pixels. It is also the largest
            shift of the crop window in each direction.
        ratio: Largest relative change of the width. The width is multiplied
            by a factor drawn uniformly from ``[1 - ratio, 1 + ratio]``.
    """

    def __init__(self, jitter: int = 10, ratio: float = 0.5) -> None:
        self.jitter = jitter
        self.ratio = ratio

    def __call__(self, img: "Image.Image") -> "Image.Image":
        """Applies the random crop and the random width rescale to an image.

        Args:
            img: Image whose ``size`` attribute unpacks as ``(width, height)``,
                i.e. a PIL image.

        Returns:
            The cropped image, resized to the original height and a randomly
            scaled width of at least one pixel.
        """
        w, h = img.size

        # Pad with white so that a shifted crop window never leaves the image.
        img = transforms.functional.pad(img, self.jitter, fill=255)

        # Crop at a random (x, y); both offsets lie in [0, 2 * jitter], which
        # keeps the (h, w) window inside the padded image.
        x = self.jitter + random.randint(-self.jitter, self.jitter)
        y = self.jitter + random.randint(-self.jitter, self.jitter)

        # Randomize the aspect ratio by rescaling the width only.
        size_w = w * random.uniform(1 - self.ratio, 1 + self.ratio)
        # Truncation can yield 0 for narrow images or a large ratio, which is
        # not a valid target size, so keep at least one pixel.
        size = (h, max(1, int(size_w)))

        return transforms.functional.resized_crop(img, y, x, h, w, size)
class Transpose:
    """Transposes the EMNIST image to the correct orientation."""

    def __call__(self, image: "Image.Image") -> np.ndarray:
        """Converts the image to an array and swaps its first two axes.

        Args:
            image: A PIL image, or any other object ``np.array`` can turn into
                an array with at least two dimensions.

        Returns:
            A NumPy array holding the image data with axes 0 and 1 exchanged.
        """
        pixels = np.array(image)
        return np.swapaxes(pixels, 0, 1)
class Resize:
    """Callable transform that rescales a tensor to a fixed width."""

    def __init__(self, width: int = 952) -> None:
        # 952 is the default because of the IAM dataset.
        self.width = width

    def __call__(self, image: Tensor) -> Tensor:
        """Interpolates ``image`` to ``self.width`` with nearest sampling.

        The width is passed to ``F.interpolate`` as a single integer size, so
        for a 3-D input only the last dimension is resized.

        Args:
            image: Tensor to resize.

        Returns:
            The resized tensor.
        """
        target_width = self.width
        resized = F.interpolate(image, size=target_width, mode="nearest")
        return resized
class AddTokens:
    """Adds start of sequence and end of sequence tokens to target tensor.

    Args:
        pad_token: Padding token; its index marks where the sequence ends.
        eos_token: End of sequence token written over the first padding slot.
        init_token: Optional start of sequence token. When it is ``None`` no
            token is prepended.
    """

    def __init__(self, pad_token: str, eos_token: str, init_token: str = None) -> None:
        self.init_token = init_token
        self.pad_token = pad_token
        self.eos_token = eos_token

        # Forward ``init_token`` only when one was given, so that the mapper
        # falls back to its own default otherwise.
        mapper_kwargs = {"pad_token": self.pad_token, "eos_token": self.eos_token}
        if self.init_token is not None:
            mapper_kwargs["init_token"] = self.init_token
        self.emnist_mapper = EmnistMapper(**mapper_kwargs)

        # Resolve the token indices once instead of on every call.
        self.pad_value = self.emnist_mapper(self.pad_token)
        self.eos_value = self.emnist_mapper(self.eos_token)
        self.sos_value = (
            self.emnist_mapper(self.init_token)
            if self.init_token is not None
            else None
        )

    def __call__(self, target: Tensor) -> Tensor:
        """Adds a sos token to the beginning and an eos token to the end of a target.

        The first padding position is overwritten with the eos value and, when
        an init token was configured, the sos value is prepended. The input
        tensor is left untouched; a new tensor is returned.

        Args:
            target: Target sequence (assumed to be 1-D) that contains at least
                one padding value.

        Returns:
            A new tensor with the tokens added. It is one element longer than
            ``target`` when a sos token is prepended.

        Raises:
            IndexError: If ``target`` contains no padding value, i.e. there is
                no slot for the eos token.
        """
        # Find where the padding starts.
        pad_positions = torch.nonzero(target == self.pad_value, as_tuple=False)
        if pad_positions.numel() == 0:
            raise IndexError(
                "Target contains no padding value; there is no position for "
                "the end of sequence token."
            )
        pad_index = pad_positions[0].item()

        # Work on a copy: writing into the caller's tensor would add another
        # eos token every time the same stored target is transformed again.
        target = target.clone()
        target[pad_index] = self.eos_value

        if self.init_token is not None:
            sos = torch.tensor(
                [self.sos_value], dtype=target.dtype, device=target.device
            )
            target = torch.cat([sos, target], dim=0)

        return target
class ApplyContrast:
    """Increases contrast by zeroing every value not above a random threshold."""

    def __init__(self, low: float = 0.0, high: float = 0.25) -> None:
        self.low, self.high = low, high

    def __call__(self, x: Tensor) -> Tensor:
        """Masks the input tensor with a randomly drawn threshold.

        Args:
            x: Tensor to mask.

        Returns:
            ``x`` multiplied by a boolean mask that is true wherever ``x`` is
            strictly greater than a threshold drawn uniformly between
            ``self.low`` and ``self.high``.
        """
        # A new, unseeded RandomState is built on every call, so the draw does
        # not depend on the global NumPy random state.
        rng = np.random.RandomState()
        threshold = rng.uniform(low=self.low, high=self.high)
        keep = x > threshold
        return x * keep
class Unsqueeze:
    """Transform that inserts a new leading dimension into a tensor."""

    def __call__(self, x: Tensor) -> Tensor:
        """Returns ``x`` with a dimension of size one added at position 0."""
        expanded = x.unsqueeze(dim=0)
        return expanded
class Squeeze:
    """Transform that drops the leading dimension of a tensor."""

    def __call__(self, x: Tensor) -> Tensor:
        """Returns ``x`` squeezed along dimension 0.

        ``Tensor.squeeze`` only removes a dimension of size one, so a tensor
        whose first dimension is larger is returned with its shape unchanged.
        """
        squeezed = x.squeeze(dim=0)
        return squeezed
class ToLower:
    """Converts target to lower case."""

    def __call__(self, target: Tensor) -> Tensor:
        """Corrects index value in target tensor.

        Every index greater than 35 is reduced by 26; all other indices are
        kept. The shift is computed for the whole tensor at once, so empty and
        multi-dimensional targets are handled as well.

        NOTE(review): presumably indices above 35 are the lower case letters of
        the EMNIST mapping -- confirm against the mapper.

        Args:
            target: Tensor of character indices.

        Returns:
            A new tensor with the same shape, dtype and device as ``target``.
        """
        return torch.where(target > 35, target - 26, target)
|