summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 03:25:28 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-10-02 03:25:28 +0200
commit4e44486aa0e87459bed4b0fe423b16e59c76c1a0 (patch)
treea0ce9f72d9000ac8e0959f159a66a7f1dbcf8892 /text_recognizer/data/transforms
parent1e0378e1ba1cdab3c064473ef951b97515f28947 (diff)
Move stems to transforms
Diffstat (limited to 'text_recognizer/data/transforms')
-rw-r--r--text_recognizer/data/transforms/image.py18
-rw-r--r--text_recognizer/data/transforms/line.py93
-rw-r--r--text_recognizer/data/transforms/paragraph.py66
3 files changed, 177 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/image.py b/text_recognizer/data/transforms/image.py
new file mode 100644
index 0000000..f04b3a0
--- /dev/null
+++ b/text_recognizer/data/transforms/image.py
@@ -0,0 +1,18 @@
+from PIL import Image
+import torch
+from torch import Tensor
+import torchvision.transforms as T
+
+
+class ImageStem:
+ def __init__(self) -> None:
+ self.pil_transform = T.Compose([])
+ self.pil_to_tensor = T.ToTensor()
+ self.torch_transform = torch.nn.Sequential()
+
+ def __call__(self, img: Image) -> Tensor:
+ img = self.pil_transform(img)
+ img = self.pil_to_tensor(img)
+ with torch.no_grad():
+ img = self.torch_transform(img)
+ return img
diff --git a/text_recognizer/data/transforms/line.py b/text_recognizer/data/transforms/line.py
new file mode 100644
index 0000000..4f0ce05
--- /dev/null
+++ b/text_recognizer/data/transforms/line.py
@@ -0,0 +1,93 @@
+import random
+from typing import Any, Dict
+
+from PIL import Image
+import torchvision.transforms as T
+
+import text_recognizer.metadata.iam_lines as metadata
+from text_recognizer.data.stems.image import ImageStem
+
+
+class LineStem(ImageStem):
+ """A stem for handling images containing a line of text."""
+
+ def __init__(
+ self,
+ augment: bool = False,
+ color_jitter_kwargs: Dict[str, Any] = None,
+ random_affine_kwargs: Dict[str, Any] = None,
+ ) -> None:
+ super().__init__()
+ if color_jitter_kwargs is None:
+ color_jitter_kwargs = {"brightness": (0.5, 1)}
+ if random_affine_kwargs is None:
+ random_affine_kwargs = {
+ "degrees": 3,
+ "translate": (0, 0.05),
+ "scale": (0.4, 1.1),
+ "shear": (-40, 50),
+ "interpolation": T.InterpolationMode.BILINEAR,
+ "fill": 0,
+ }
+
+ if augment:
+ self.pil_transforms = T.Compose(
+ [
+ T.ColorJitter(**color_jitter_kwargs),
+ T.RandomAffine(**random_affine_kwargs),
+ ]
+ )
+
+
+class IamLinesStem(ImageStem):
+ """A stem for handling images containing lines of text from the IAMLines dataset."""
+
+ def __init__(
+ self,
+ augment: bool = False,
+ color_jitter_kwargs: Dict[str, Any] = None,
+ random_affine_kwargs: Dict[str, Any] = None,
+ ) -> None:
+ super().__init__()
+
+ def embed_crop(crop, augment=augment):
+ # crop is PIL.image of dtype="L" (so values range from 0 -> 255)
+ image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))
+
+ # Resize crop
+ crop_width, crop_height = crop.size
+ new_crop_height = metadata.IMAGE_HEIGHT
+ new_crop_width = int(new_crop_height * (crop_width / crop_height))
+ if augment:
+ # Add random stretching
+ new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
+ new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH)
+ crop_resized = crop.resize(
+ (new_crop_width, new_crop_height), resample=Image.BILINEAR
+ )
+
+ # Embed in the image
+ x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width)
+ y = metadata.IMAGE_HEIGHT - new_crop_height
+
+ image.paste(crop_resized, (x, y))
+
+ return image
+
+ if color_jitter_kwargs is None:
+ color_jitter_kwargs = {"brightness": (0.8, 1.6)}
+ if random_affine_kwargs is None:
+ random_affine_kwargs = {
+ "degrees": 1,
+ "shear": (-30, 20),
+ "interpolation": T.InterpolationMode.BILINEAR,
+ "fill": 0,
+ }
+
+ pil_transform_list = [T.Lambda(embed_crop)]
+ if augment:
+ pil_transform_list += [
+ T.ColorJitter(**color_jitter_kwargs),
+ T.RandomAffine(**random_affine_kwargs),
+ ]
+ self.pil_transform = T.Compose(pil_transform_list)
diff --git a/text_recognizer/data/transforms/paragraph.py b/text_recognizer/data/transforms/paragraph.py
new file mode 100644
index 0000000..39e1e59
--- /dev/null
+++ b/text_recognizer/data/transforms/paragraph.py
@@ -0,0 +1,66 @@
+"""Iam paragraph stem class."""
+import torchvision.transforms as T
+
+import text_recognizer.metadata.iam_paragraphs as metadata
+from text_recognizer.data.stems.image import ImageStem
+
+
+IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
+IMAGE_SHAPE = metadata.IMAGE_SHAPE
+
+MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH
+
+
+class ParagraphStem(ImageStem):
+ """A stem for handling images that contain a paragraph of text."""
+
+ def __init__(
+ self,
+ augment=False,
+ color_jitter_kwargs=None,
+ random_affine_kwargs=None,
+ random_perspective_kwargs=None,
+ gaussian_blur_kwargs=None,
+ sharpness_kwargs=None,
+ ):
+ super().__init__()
+
+ if not augment:
+ self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)])
+ else:
+ if color_jitter_kwargs is None:
+ color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4}
+ if random_affine_kwargs is None:
+ random_affine_kwargs = {
+ "degrees": 3,
+ "shear": 6,
+ "scale": (0.95, 1),
+ "interpolation": T.InterpolationMode.BILINEAR,
+ }
+ if random_perspective_kwargs is None:
+ random_perspective_kwargs = {
+ "distortion_scale": 0.2,
+ "p": 0.5,
+ "interpolation": T.InterpolationMode.BILINEAR,
+ }
+ if gaussian_blur_kwargs is None:
+ gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
+ if sharpness_kwargs is None:
+ sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5}
+
+ self.pil_transform = T.Compose(
+ [
+ T.ColorJitter(**color_jitter_kwargs),
+ T.RandomCrop(
+ size=IMAGE_SHAPE,
+ padding=None,
+ pad_if_needed=True,
+ fill=0,
+ padding_mode="constant",
+ ),
+ T.RandomAffine(**random_affine_kwargs),
+ T.RandomPerspective(**random_perspective_kwargs),
+ T.GaussianBlur(**gaussian_blur_kwargs),
+ T.RandomAdjustSharpness(**sharpness_kwargs),
+ ]
+ )