diff options
Diffstat (limited to 'text_recognizer/data/transforms/embed_crop.py')
-rw-r--r-- | text_recognizer/data/transforms/embed_crop.py | 37 |
1 files changed, 37 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/embed_crop.py b/text_recognizer/data/transforms/embed_crop.py new file mode 100644 index 0000000..7421d0e --- /dev/null +++ b/text_recognizer/data/transforms/embed_crop.py @@ -0,0 +1,37 @@ +"""Transforms for PyTorch datasets.""" +import random + +from PIL import Image + + +class EmbedCrop: + + IMAGE_HEIGHT = 56 + IMAGE_WIDTH = 1024 + + def __init__(self, augment: bool) -> None: + self.augment = augment + + def __call__(self, crop: Image) -> Image: + # Crop is PIL.Image of dtype="L" (so value range is [0, 255]) + image = Image.new("L", (self.IMAGE_WIDTH, self.IMAGE_HEIGHT)) + + # Resize crop. + crop_width, crop_height = crop.size + new_crop_height = self.IMAGE_HEIGHT + new_crop_width = int(new_crop_height * (crop_width / crop_height)) + + if self.augment: + # Add random stretching + new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) + new_crop_width = min(new_crop_width, self.IMAGE_WIDTH) + crop_resized = crop.resize( + (new_crop_width, new_crop_height), resample=Image.BILINEAR + ) + + # Embed in image + x = min(28, self.IMAGE_WIDTH - new_crop_width) + y = self.IMAGE_HEIGHT - new_crop_height + image.paste(crop_resized, (x, y)) + + return image |