summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/embed_crop.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:05:24 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:05:24 +0200
commit38f546f0b86fc0dc89863b00c5ee8c6685394ef2 (patch)
treedeaf5501c8e687e235add858ad92e068df7d3615 /text_recognizer/data/transforms/embed_crop.py
parent8291a87c64f9a5f18caec82201bea15579b49730 (diff)
Add custom transforms
Diffstat (limited to 'text_recognizer/data/transforms/embed_crop.py')
-rw-r--r--text_recognizer/data/transforms/embed_crop.py37
1 file changed, 37 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/embed_crop.py b/text_recognizer/data/transforms/embed_crop.py
new file mode 100644
index 0000000..7421d0e
--- /dev/null
+++ b/text_recognizer/data/transforms/embed_crop.py
@@ -0,0 +1,37 @@
+"""Transforms for PyTorch datasets."""
+import random
+
+from PIL import Image
+
+
class EmbedCrop:
    """Embed a cropped text-line image onto a fixed-size blank canvas.

    The crop is rescaled to the canvas height (preserving aspect ratio,
    with optional random horizontal stretching when augmenting) and then
    pasted near the left edge of a black ``IMAGE_WIDTH`` x ``IMAGE_HEIGHT``
    grayscale canvas.
    """

    # Output canvas dimensions in pixels.
    IMAGE_HEIGHT = 56
    IMAGE_WIDTH = 1024

    def __init__(self, augment: bool) -> None:
        """Store whether random width stretching is applied on call."""
        self.augment = augment

    def __call__(self, crop: Image.Image) -> Image.Image:
        """Resize ``crop`` and embed it in a fixed-size canvas.

        Args:
            crop: Grayscale PIL image (mode "L", values in [0, 255]).
                Assumed to have nonzero height — TODO confirm upstream.

        Returns:
            A new mode-"L" image of size (IMAGE_WIDTH, IMAGE_HEIGHT) with
            the resized crop pasted onto it; uncovered pixels are black.
        """
        # Black canvas the resized crop is pasted onto.
        image = Image.new("L", (self.IMAGE_WIDTH, self.IMAGE_HEIGHT))

        # Scale the crop to the full canvas height, preserving aspect ratio.
        crop_width, crop_height = crop.size
        new_crop_height = self.IMAGE_HEIGHT
        new_crop_width = int(new_crop_height * (crop_width / crop_height))

        if self.augment:
            # Random horizontal stretch/squeeze of up to 10% either way.
            new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
        # Clamp so the crop never exceeds the canvas width.
        new_crop_width = min(new_crop_width, self.IMAGE_WIDTH)
        crop_resized = crop.resize(
            (new_crop_width, new_crop_height), resample=Image.BILINEAR
        )

        # Paste with a left margin of at most 28 px (less if the crop is
        # nearly canvas-wide); y is always 0 here because the crop height
        # equals the canvas height.
        x = min(28, self.IMAGE_WIDTH - new_crop_width)
        y = self.IMAGE_HEIGHT - new_crop_height
        image.paste(crop_resized, (x, y))

        return image