"""Transforms for PyTorch datasets."""
import random

from PIL import Image


class EmbedCrop:
    """Resize a line-image crop and embed it in a fixed-size canvas."""

    IMAGE_HEIGHT = 56
    IMAGE_WIDTH = 1024

    def __init__(self, augment: bool) -> None:
        self.augment = augment

    def __call__(self, crop: Image.Image) -> Image.Image:
        # Crop is a PIL image in mode "L" (grayscale, value range [0, 255]).
        image = Image.new("L", (self.IMAGE_WIDTH, self.IMAGE_HEIGHT))

        # Resize the crop to the target height, preserving its aspect ratio.
        crop_width, crop_height = crop.size
        new_crop_height = self.IMAGE_HEIGHT
        new_crop_width = int(new_crop_height * (crop_width / crop_height))
        if self.augment:
            # Add random horizontal stretching as augmentation.
            new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
        new_crop_width = min(new_crop_width, self.IMAGE_WIDTH)
        crop_resized = crop.resize(
            (new_crop_width, new_crop_height), resample=Image.BILINEAR
        )

        # Embed the resized crop in the canvas, bottom-aligned, with a left
        # offset of at most 28 pixels.
        x = min(28, self.IMAGE_WIDTH - new_crop_width)
        y = self.IMAGE_HEIGHT - new_crop_height
        image.paste(crop_resized, (x, y))
        return image
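
# The demo below is not part of the original module; it is a minimal usage
# sketch showing how EmbedCrop might be applied to a synthetic grayscale crop.
# The crop dimensions and fill value are arbitrary stand-ins for a real
# text-line image.
if __name__ == "__main__":
    sample_crop = Image.new("L", (300, 40), color=200)  # hypothetical crop
    transform = EmbedCrop(augment=True)
    padded = transform(sample_crop)
    print(padded.size)  # Expected output: (1024, 56)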