From 4e44486aa0e87459bed4b0fe423b16e59c76c1a0 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 2 Oct 2022 03:25:28 +0200
Subject: Move stems to transforms
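
The stem modules are image transforms, so move them from
text_recognizer/data/stems/ into text_recognizer/data/transforms/,
alongside the existing Pad transform, and point the Hydra configs at
the new module paths. While moving, make the relocated modules import
ImageStem from its new location, correct the self.pil_transforms typo
in LineStem (the base class applies self.pil_transform, so the
augmentations were silently skipped), and tighten a few type
annotations (Optional kwargs, PIL Image.Image).

The conv_transformer_lines experiment no longer overrides the
datamodule transform; it falls back to the augmenting default from
iam_lines.yaml. Code that builds stems directly should now import
them from the transforms package, e.g. (illustrative only):

    from text_recognizer.data.transforms.line import IamLinesStem

    transform = IamLinesStem(augment=True)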

---
 text_recognizer/data/stems/__init__.py             |  0
 text_recognizer/data/stems/image.py                | 18 -----
 text_recognizer/data/stems/line.py                 | 93 ----------------------
 text_recognizer/data/stems/paragraph.py            | 66 ---------------
 text_recognizer/data/transforms/image.py           | 18 +++++
 text_recognizer/data/transforms/line.py            | 93 ++++++++++++++++++++++
 text_recognizer/data/transforms/paragraph.py       | 66 +++++++++++++++
 .../conf/datamodule/iam_extended_paragraphs.yaml   |  4 +-
 training/conf/datamodule/iam_lines.yaml            |  4 +-
 .../conf/experiment/conv_transformer_lines.yaml    |  3 -
 10 files changed, 181 insertions(+), 184 deletions(-)
 delete mode 100644 text_recognizer/data/stems/__init__.py
 delete mode 100644 text_recognizer/data/stems/image.py
 delete mode 100644 text_recognizer/data/stems/line.py
 delete mode 100644 text_recognizer/data/stems/paragraph.py
 create mode 100644 text_recognizer/data/transforms/image.py
 create mode 100644 text_recognizer/data/transforms/line.py
 create mode 100644 text_recognizer/data/transforms/paragraph.py

diff --git a/text_recognizer/data/stems/__init__.py b/text_recognizer/data/stems/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/text_recognizer/data/stems/image.py b/text_recognizer/data/stems/image.py
deleted file mode 100644
index f04b3a0..0000000
--- a/text_recognizer/data/stems/image.py
+++ /dev/null
@@ -1,18 +0,0 @@
-from PIL import Image
-import torch
-from torch import Tensor
-import torchvision.transforms as T
-
-
-class ImageStem:
-    def __init__(self) -> None:
-        self.pil_transform = T.Compose([])
-        self.pil_to_tensor = T.ToTensor()
-        self.torch_transform = torch.nn.Sequential()
-
-    def __call__(self, img: Image) -> Tensor:
-        img = self.pil_transform(img)
-        img = self.pil_to_tensor(img)
-        with torch.no_grad():
-            img = self.torch_transform(img)
-        return img
diff --git a/text_recognizer/data/stems/line.py b/text_recognizer/data/stems/line.py
deleted file mode 100644
index 4f0ce05..0000000
--- a/text_recognizer/data/stems/line.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import random
-from typing import Any, Dict
-
-from PIL import Image
-import torchvision.transforms as T
-
-import text_recognizer.metadata.iam_lines as metadata
-from text_recognizer.data.stems.image import ImageStem
-
-
-class LineStem(ImageStem):
-    """A stem for handling images containing a line of text."""
-
-    def __init__(
-        self,
-        augment: bool = False,
-        color_jitter_kwargs: Dict[str, Any] = None,
-        random_affine_kwargs: Dict[str, Any] = None,
-    ) -> None:
-        super().__init__()
-        if color_jitter_kwargs is None:
-            color_jitter_kwargs = {"brightness": (0.5, 1)}
-        if random_affine_kwargs is None:
-            random_affine_kwargs = {
-                "degrees": 3,
-                "translate": (0, 0.05),
-                "scale": (0.4, 1.1),
-                "shear": (-40, 50),
-                "interpolation": T.InterpolationMode.BILINEAR,
-                "fill": 0,
-            }
-
-        if augment:
-            self.pil_transforms = T.Compose(
-                [
-                    T.ColorJitter(**color_jitter_kwargs),
-                    T.RandomAffine(**random_affine_kwargs),
-                ]
-            )
-
-
-class IamLinesStem(ImageStem):
-    """A stem for handling images containing lines of text from the IAMLines dataset."""
-
-    def __init__(
-        self,
-        augment: bool = False,
-        color_jitter_kwargs: Dict[str, Any] = None,
-        random_affine_kwargs: Dict[str, Any] = None,
-    ) -> None:
-        super().__init__()
-
-        def embed_crop(crop, augment=augment):
-            # crop is PIL.image of dtype="L" (so values range from 0 -> 255)
-            image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))
-
-            # Resize crop
-            crop_width, crop_height = crop.size
-            new_crop_height = metadata.IMAGE_HEIGHT
-            new_crop_width = int(new_crop_height * (crop_width / crop_height))
-            if augment:
-                # Add random stretching
-                new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
-                new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH)
-            crop_resized = crop.resize(
-                (new_crop_width, new_crop_height), resample=Image.BILINEAR
-            )
-
-            # Embed in the image
-            x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width)
-            y = metadata.IMAGE_HEIGHT - new_crop_height
-
-            image.paste(crop_resized, (x, y))
-
-            return image
-
-        if color_jitter_kwargs is None:
-            color_jitter_kwargs = {"brightness": (0.8, 1.6)}
-        if random_affine_kwargs is None:
-            random_affine_kwargs = {
-                "degrees": 1,
-                "shear": (-30, 20),
-                "interpolation": T.InterpolationMode.BILINEAR,
-                "fill": 0,
-            }
-
-        pil_transform_list = [T.Lambda(embed_crop)]
-        if augment:
-            pil_transform_list += [
-                T.ColorJitter(**color_jitter_kwargs),
-                T.RandomAffine(**random_affine_kwargs),
-            ]
-        self.pil_transform = T.Compose(pil_transform_list)
diff --git a/text_recognizer/data/stems/paragraph.py b/text_recognizer/data/stems/paragraph.py
deleted file mode 100644
index 39e1e59..0000000
--- a/text_recognizer/data/stems/paragraph.py
+++ /dev/null
@@ -1,66 +0,0 @@
-"""Iam paragraph stem class."""
-import torchvision.transforms as T
-
-import text_recognizer.metadata.iam_paragraphs as metadata
-from text_recognizer.data.stems.image import ImageStem
-
-
-IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
-IMAGE_SHAPE = metadata.IMAGE_SHAPE
-
-MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH
-
-
-class ParagraphStem(ImageStem):
-    """A stem for handling images that contain a paragraph of text."""
-
-    def __init__(
-        self,
-        augment=False,
-        color_jitter_kwargs=None,
-        random_affine_kwargs=None,
-        random_perspective_kwargs=None,
-        gaussian_blur_kwargs=None,
-        sharpness_kwargs=None,
-    ):
-        super().__init__()
-
-        if not augment:
-            self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)])
-        else:
-            if color_jitter_kwargs is None:
-                color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4}
-            if random_affine_kwargs is None:
-                random_affine_kwargs = {
-                    "degrees": 3,
-                    "shear": 6,
-                    "scale": (0.95, 1),
-                    "interpolation": T.InterpolationMode.BILINEAR,
-                }
-            if random_perspective_kwargs is None:
-                random_perspective_kwargs = {
-                    "distortion_scale": 0.2,
-                    "p": 0.5,
-                    "interpolation": T.InterpolationMode.BILINEAR,
-                }
-            if gaussian_blur_kwargs is None:
-                gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
-            if sharpness_kwargs is None:
-                sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5}
-
-            self.pil_transform = T.Compose(
-                [
-                    T.ColorJitter(**color_jitter_kwargs),
-                    T.RandomCrop(
-                        size=IMAGE_SHAPE,
-                        padding=None,
-                        pad_if_needed=True,
-                        fill=0,
-                        padding_mode="constant",
-                    ),
-                    T.RandomAffine(**random_affine_kwargs),
-                    T.RandomPerspective(**random_perspective_kwargs),
-                    T.GaussianBlur(**gaussian_blur_kwargs),
-                    T.RandomAdjustSharpness(**sharpness_kwargs),
-                ]
-            )
diff --git a/text_recognizer/data/transforms/image.py b/text_recognizer/data/transforms/image.py
new file mode 100644
index 0000000..f04b3a0
--- /dev/null
+++ b/text_recognizer/data/transforms/image.py
@@ -0,0 +1,18 @@
+from PIL import Image
+import torch
+from torch import Tensor
+import torchvision.transforms as T
+
+
+class ImageStem:
+    def __init__(self) -> None:
+        self.pil_transform = T.Compose([])
+        self.pil_to_tensor = T.ToTensor()
+        self.torch_transform = torch.nn.Sequential()
+
+    def __call__(self, img: Image.Image) -> Tensor:
+        img = self.pil_transform(img)
+        img = self.pil_to_tensor(img)
+        with torch.no_grad():
+            img = self.torch_transform(img)
+        return img
diff --git a/text_recognizer/data/transforms/line.py b/text_recognizer/data/transforms/line.py
new file mode 100644
index 0000000..4f0ce05
--- /dev/null
+++ b/text_recognizer/data/transforms/line.py
@@ -0,0 +1,93 @@
+import random
+from typing import Any, Dict, Optional
+
+from PIL import Image
+import torchvision.transforms as T
+
+import text_recognizer.metadata.iam_lines as metadata
+from text_recognizer.data.transforms.image import ImageStem
+
+
+class LineStem(ImageStem):
+    """A stem for handling images containing a line of text."""
+
+    def __init__(
+        self,
+        augment: bool = False,
+        color_jitter_kwargs: Optional[Dict[str, Any]] = None,
+        random_affine_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        super().__init__()
+        if color_jitter_kwargs is None:
+            color_jitter_kwargs = {"brightness": (0.5, 1)}
+        if random_affine_kwargs is None:
+            random_affine_kwargs = {
+                "degrees": 3,
+                "translate": (0, 0.05),
+                "scale": (0.4, 1.1),
+                "shear": (-40, 50),
+                "interpolation": T.InterpolationMode.BILINEAR,
+                "fill": 0,
+            }
+
+        if augment:
+            self.pil_transform = T.Compose(
+                [
+                    T.ColorJitter(**color_jitter_kwargs),
+                    T.RandomAffine(**random_affine_kwargs),
+                ]
+            )
+
+
+class IamLinesStem(ImageStem):
+    """A stem for handling images containing lines of text from the IAMLines dataset."""
+
+    def __init__(
+        self,
+        augment: bool = False,
+        color_jitter_kwargs: Optional[Dict[str, Any]] = None,
+        random_affine_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> None:
+        super().__init__()
+
+        def embed_crop(crop, augment=augment):
+            # crop is PIL.image of dtype="L" (so values range from 0 -> 255)
+            image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))
+
+            # Resize crop
+            crop_width, crop_height = crop.size
+            new_crop_height = metadata.IMAGE_HEIGHT
+            new_crop_width = int(new_crop_height * (crop_width / crop_height))
+            if augment:
+                # Add random stretching
+                new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
+                new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH)
+            crop_resized = crop.resize(
+                (new_crop_width, new_crop_height), resample=Image.BILINEAR
+            )
+
+            # Embed in the image
+            x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width)
+            y = metadata.IMAGE_HEIGHT - new_crop_height
+
+            image.paste(crop_resized, (x, y))
+
+            return image
+
+        if color_jitter_kwargs is None:
+            color_jitter_kwargs = {"brightness": (0.8, 1.6)}
+        if random_affine_kwargs is None:
+            random_affine_kwargs = {
+                "degrees": 1,
+                "shear": (-30, 20),
+                "interpolation": T.InterpolationMode.BILINEAR,
+                "fill": 0,
+            }
+
+        pil_transform_list = [T.Lambda(embed_crop)]
+        if augment:
+            pil_transform_list += [
+                T.ColorJitter(**color_jitter_kwargs),
+                T.RandomAffine(**random_affine_kwargs),
+            ]
+        self.pil_transform = T.Compose(pil_transform_list)
diff --git a/text_recognizer/data/transforms/paragraph.py b/text_recognizer/data/transforms/paragraph.py
new file mode 100644
index 0000000..39e1e59
--- /dev/null
+++ b/text_recognizer/data/transforms/paragraph.py
@@ -0,0 +1,66 @@
+"""Iam paragraph stem class."""
+import torchvision.transforms as T
+
+import text_recognizer.metadata.iam_paragraphs as metadata
+from text_recognizer.data.transforms.image import ImageStem
+
+
+IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH
+IMAGE_SHAPE = metadata.IMAGE_SHAPE
+
+MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH
+
+
+class ParagraphStem(ImageStem):
+    """A stem for handling images that contain a paragraph of text."""
+
+    def __init__(
+        self,
+        augment=False,
+        color_jitter_kwargs=None,
+        random_affine_kwargs=None,
+        random_perspective_kwargs=None,
+        gaussian_blur_kwargs=None,
+        sharpness_kwargs=None,
+    ):
+        super().__init__()
+
+        if not augment:
+            self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)])
+        else:
+            if color_jitter_kwargs is None:
+                color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4}
+            if random_affine_kwargs is None:
+                random_affine_kwargs = {
+                    "degrees": 3,
+                    "shear": 6,
+                    "scale": (0.95, 1),
+                    "interpolation": T.InterpolationMode.BILINEAR,
+                }
+            if random_perspective_kwargs is None:
+                random_perspective_kwargs = {
+                    "distortion_scale": 0.2,
+                    "p": 0.5,
+                    "interpolation": T.InterpolationMode.BILINEAR,
+                }
+            if gaussian_blur_kwargs is None:
+                gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
+            if sharpness_kwargs is None:
+                sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5}
+
+            self.pil_transform = T.Compose(
+                [
+                    T.ColorJitter(**color_jitter_kwargs),
+                    T.RandomCrop(
+                        size=IMAGE_SHAPE,
+                        padding=None,
+                        pad_if_needed=True,
+                        fill=0,
+                        padding_mode="constant",
+                    ),
+                    T.RandomAffine(**random_affine_kwargs),
+                    T.RandomPerspective(**random_perspective_kwargs),
+                    T.GaussianBlur(**gaussian_blur_kwargs),
+                    T.RandomAdjustSharpness(**sharpness_kwargs),
+                ]
+            )
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index 64c3964..e4ef896 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -4,10 +4,10 @@ num_workers: 12
 train_fraction: 0.8
 pin_memory: true
 transform:
-  _target_: text_recognizer.data.stems.paragraph.ParagraphStem
+  _target_: text_recognizer.data.transforms.paragraph.ParagraphStem
   augment: true
 test_transform:
-  _target_: text_recognizer.data.stems.paragraph.ParagraphStem
+  _target_: text_recognizer.data.transforms.paragraph.ParagraphStem
   augment: false
 target_transform:
   _target_: text_recognizer.data.transforms.pad.Pad
diff --git a/training/conf/datamodule/iam_lines.yaml b/training/conf/datamodule/iam_lines.yaml
index f84116d..1205c75 100644
--- a/training/conf/datamodule/iam_lines.yaml
+++ b/training/conf/datamodule/iam_lines.yaml
@@ -4,10 +4,10 @@ num_workers: 12
 train_fraction: 0.9
 pin_memory: true
 transform:
-  _target_: text_recognizer.data.stems.line.IamLinesStem
+  _target_: text_recognizer.data.transforms.line.IamLinesStem
   augment: true
 test_transform:
-  _target_: text_recognizer.data.stems.line.IamLinesStem
+  _target_: text_recognizer.data.transforms.line.IamLinesStem
   augment: false
 tokenizer:
   _target_: text_recognizer.data.tokenizer.Tokenizer
diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml
index 3f5da86..948968a 100644
--- a/training/conf/experiment/conv_transformer_lines.yaml
+++ b/training/conf/experiment/conv_transformer_lines.yaml
@@ -54,9 +54,6 @@ lr_scheduler:
 datamodule:
   batch_size: 16
   train_fraction: 0.95
-  transform:
-    _target_: text_recognizer.data.stems.line.IamLinesStem
-    augment: false
 
 network:
   _target_: text_recognizer.networks.ConvTransformer
-- 
cgit v1.2.3-70-g09d2