From 4e44486aa0e87459bed4b0fe423b16e59c76c1a0 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 Oct 2022 03:25:28 +0200 Subject: Move stems to transforms --- text_recognizer/data/stems/__init__.py | 0 text_recognizer/data/stems/image.py | 18 ----- text_recognizer/data/stems/line.py | 93 ---------------------- text_recognizer/data/stems/paragraph.py | 66 --------------- text_recognizer/data/transforms/image.py | 18 +++++ text_recognizer/data/transforms/line.py | 93 ++++++++++++++++++++++ text_recognizer/data/transforms/paragraph.py | 66 +++++++++++++++ .../conf/datamodule/iam_extended_paragraphs.yaml | 4 +- training/conf/datamodule/iam_lines.yaml | 4 +- .../conf/experiment/conv_transformer_lines.yaml | 3 - 10 files changed, 181 insertions(+), 184 deletions(-) delete mode 100644 text_recognizer/data/stems/__init__.py delete mode 100644 text_recognizer/data/stems/image.py delete mode 100644 text_recognizer/data/stems/line.py delete mode 100644 text_recognizer/data/stems/paragraph.py create mode 100644 text_recognizer/data/transforms/image.py create mode 100644 text_recognizer/data/transforms/line.py create mode 100644 text_recognizer/data/transforms/paragraph.py diff --git a/text_recognizer/data/stems/__init__.py b/text_recognizer/data/stems/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/text_recognizer/data/stems/image.py b/text_recognizer/data/stems/image.py deleted file mode 100644 index f04b3a0..0000000 --- a/text_recognizer/data/stems/image.py +++ /dev/null @@ -1,18 +0,0 @@ -from PIL import Image -import torch -from torch import Tensor -import torchvision.transforms as T - - -class ImageStem: - def __init__(self) -> None: - self.pil_transform = T.Compose([]) - self.pil_to_tensor = T.ToTensor() - self.torch_transform = torch.nn.Sequential() - - def __call__(self, img: Image) -> Tensor: - img = self.pil_transform(img) - img = self.pil_to_tensor(img) - with torch.no_grad(): - img = self.torch_transform(img) - return img diff --git a/text_recognizer/data/stems/line.py b/text_recognizer/data/stems/line.py deleted file mode 100644 index 4f0ce05..0000000 --- a/text_recognizer/data/stems/line.py +++ /dev/null @@ -1,93 +0,0 @@ -import random -from typing import Any, Dict - -from PIL import Image -import torchvision.transforms as T - -import text_recognizer.metadata.iam_lines as metadata -from text_recognizer.data.stems.image import ImageStem - - -class LineStem(ImageStem): - """A stem for handling images containing a line of text.""" - - def __init__( - self, - augment: bool = False, - color_jitter_kwargs: Dict[str, Any] = None, - random_affine_kwargs: Dict[str, Any] = None, - ) -> None: - super().__init__() - if color_jitter_kwargs is None: - color_jitter_kwargs = {"brightness": (0.5, 1)} - if random_affine_kwargs is None: - random_affine_kwargs = { - "degrees": 3, - "translate": (0, 0.05), - "scale": (0.4, 1.1), - "shear": (-40, 50), - "interpolation": T.InterpolationMode.BILINEAR, - "fill": 0, - } - - if augment: - self.pil_transforms = T.Compose( - [ - T.ColorJitter(**color_jitter_kwargs), - T.RandomAffine(**random_affine_kwargs), - ] - ) - - -class IamLinesStem(ImageStem): - """A stem for handling images containing lines of text from the IAMLines dataset.""" - - def __init__( - self, - augment: bool = False, - color_jitter_kwargs: Dict[str, Any] = None, - random_affine_kwargs: Dict[str, Any] = None, - ) -> None: - super().__init__() - - def embed_crop(crop, augment=augment): - # crop is PIL.image of dtype="L" (so values range from 0 -> 255) - image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) - - # Resize crop - crop_width, crop_height = crop.size - new_crop_height = metadata.IMAGE_HEIGHT - new_crop_width = int(new_crop_height * (crop_width / crop_height)) - if augment: - # Add random stretching - new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) - new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) - crop_resized = crop.resize( - (new_crop_width, new_crop_height), resample=Image.BILINEAR - ) - - # Embed in the image - x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) - y = metadata.IMAGE_HEIGHT - new_crop_height - - image.paste(crop_resized, (x, y)) - - return image - - if color_jitter_kwargs is None: - color_jitter_kwargs = {"brightness": (0.8, 1.6)} - if random_affine_kwargs is None: - random_affine_kwargs = { - "degrees": 1, - "shear": (-30, 20), - "interpolation": T.InterpolationMode.BILINEAR, - "fill": 0, - } - - pil_transform_list = [T.Lambda(embed_crop)] - if augment: - pil_transform_list += [ - T.ColorJitter(**color_jitter_kwargs), - T.RandomAffine(**random_affine_kwargs), - ] - self.pil_transform = T.Compose(pil_transform_list) diff --git a/text_recognizer/data/stems/paragraph.py b/text_recognizer/data/stems/paragraph.py deleted file mode 100644 index 39e1e59..0000000 --- a/text_recognizer/data/stems/paragraph.py +++ /dev/null @@ -1,66 +0,0 @@ -"""Iam paragraph stem class.""" -import torchvision.transforms as T - -import text_recognizer.metadata.iam_paragraphs as metadata -from text_recognizer.data.stems.image import ImageStem - - -IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH -IMAGE_SHAPE = metadata.IMAGE_SHAPE - -MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH - - -class ParagraphStem(ImageStem): - """A stem for handling images that contain a paragraph of text.""" - - def __init__( - self, - augment=False, - color_jitter_kwargs=None, - random_affine_kwargs=None, - random_perspective_kwargs=None, - gaussian_blur_kwargs=None, - sharpness_kwargs=None, - ): - super().__init__() - - if not augment: - self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)]) - else: - if color_jitter_kwargs is None: - color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} - if random_affine_kwargs is None: - random_affine_kwargs = { - "degrees": 3, - "shear": 6, - "scale": (0.95, 1), - "interpolation": T.InterpolationMode.BILINEAR, - } - if random_perspective_kwargs is None: - random_perspective_kwargs = { - "distortion_scale": 0.2, - "p": 0.5, - "interpolation": T.InterpolationMode.BILINEAR, - } - if gaussian_blur_kwargs is None: - gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} - if sharpness_kwargs is None: - sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} - - self.pil_transform = T.Compose( - [ - T.ColorJitter(**color_jitter_kwargs), - T.RandomCrop( - size=IMAGE_SHAPE, - padding=None, - pad_if_needed=True, - fill=0, - padding_mode="constant", - ), - T.RandomAffine(**random_affine_kwargs), - T.RandomPerspective(**random_perspective_kwargs), - T.GaussianBlur(**gaussian_blur_kwargs), - T.RandomAdjustSharpness(**sharpness_kwargs), - ] - ) diff --git a/text_recognizer/data/transforms/image.py b/text_recognizer/data/transforms/image.py new file mode 100644 index 0000000..f04b3a0 --- /dev/null +++ b/text_recognizer/data/transforms/image.py @@ -0,0 +1,18 @@ +from PIL import Image +import torch +from torch import Tensor +import torchvision.transforms as T + + +class ImageStem: + def __init__(self) -> None: + self.pil_transform = T.Compose([]) + self.pil_to_tensor = T.ToTensor() + self.torch_transform = torch.nn.Sequential() + + def __call__(self, img: Image) -> Tensor: + img = self.pil_transform(img) + img = self.pil_to_tensor(img) + with torch.no_grad(): + img = self.torch_transform(img) + return img diff --git a/text_recognizer/data/transforms/line.py b/text_recognizer/data/transforms/line.py new file mode 100644 index 0000000..4f0ce05 --- /dev/null +++ b/text_recognizer/data/transforms/line.py @@ -0,0 +1,93 @@ +import random +from typing import Any, Dict + +from PIL import Image +import torchvision.transforms as T + +import text_recognizer.metadata.iam_lines as metadata +from text_recognizer.data.stems.image import ImageStem + + +class LineStem(ImageStem): + """A stem for handling images containing a line of text.""" + + def __init__( + self, + augment: bool = False, + color_jitter_kwargs: Dict[str, Any] = None, + random_affine_kwargs: Dict[str, Any] = None, + ) -> None: + super().__init__() + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": (0.5, 1)} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 3, + "translate": (0, 0.05), + "scale": (0.4, 1.1), + "shear": (-40, 50), + "interpolation": T.InterpolationMode.BILINEAR, + "fill": 0, + } + + if augment: + self.pil_transforms = T.Compose( + [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomAffine(**random_affine_kwargs), + ] + ) + + +class IamLinesStem(ImageStem): + """A stem for handling images containing lines of text from the IAMLines dataset.""" + + def __init__( + self, + augment: bool = False, + color_jitter_kwargs: Dict[str, Any] = None, + random_affine_kwargs: Dict[str, Any] = None, + ) -> None: + super().__init__() + + def embed_crop(crop, augment=augment): + # crop is PIL.image of dtype="L" (so values range from 0 -> 255) + image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) + + # Resize crop + crop_width, crop_height = crop.size + new_crop_height = metadata.IMAGE_HEIGHT + new_crop_width = int(new_crop_height * (crop_width / crop_height)) + if augment: + # Add random stretching + new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) + new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) + crop_resized = crop.resize( + (new_crop_width, new_crop_height), resample=Image.BILINEAR + ) + + # Embed in the image + x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) + y = metadata.IMAGE_HEIGHT - new_crop_height + + image.paste(crop_resized, (x, y)) + + return image + + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": (0.8, 1.6)} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 1, + "shear": (-30, 20), + "interpolation": T.InterpolationMode.BILINEAR, + "fill": 0, + } + + pil_transform_list = [T.Lambda(embed_crop)] + if augment: + pil_transform_list += [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomAffine(**random_affine_kwargs), + ] + self.pil_transform = T.Compose(pil_transform_list) diff --git a/text_recognizer/data/transforms/paragraph.py b/text_recognizer/data/transforms/paragraph.py new file mode 100644 index 0000000..39e1e59 --- /dev/null +++ b/text_recognizer/data/transforms/paragraph.py @@ -0,0 +1,66 @@ +"""Iam paragraph stem class.""" +import torchvision.transforms as T + +import text_recognizer.metadata.iam_paragraphs as metadata +from text_recognizer.data.stems.image import ImageStem + + +IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH +IMAGE_SHAPE = metadata.IMAGE_SHAPE + +MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH + + +class ParagraphStem(ImageStem): + """A stem for handling images that contain a paragraph of text.""" + + def __init__( + self, + augment=False, + color_jitter_kwargs=None, + random_affine_kwargs=None, + random_perspective_kwargs=None, + gaussian_blur_kwargs=None, + sharpness_kwargs=None, + ): + super().__init__() + + if not augment: + self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)]) + else: + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 3, + "shear": 6, + "scale": (0.95, 1), + "interpolation": T.InterpolationMode.BILINEAR, + } + if random_perspective_kwargs is None: + random_perspective_kwargs = { + "distortion_scale": 0.2, + "p": 0.5, + "interpolation": T.InterpolationMode.BILINEAR, + } + if gaussian_blur_kwargs is None: + gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} + if sharpness_kwargs is None: + sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} + + self.pil_transform = T.Compose( + [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomCrop( + size=IMAGE_SHAPE, + padding=None, + pad_if_needed=True, + fill=0, + padding_mode="constant", + ), + T.RandomAffine(**random_affine_kwargs), + T.RandomPerspective(**random_perspective_kwargs), + T.GaussianBlur(**gaussian_blur_kwargs), + T.RandomAdjustSharpness(**sharpness_kwargs), + ] + ) diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml index 64c3964..e4ef896 100644 --- a/training/conf/datamodule/iam_extended_paragraphs.yaml +++ b/training/conf/datamodule/iam_extended_paragraphs.yaml @@ -4,10 +4,10 @@ num_workers: 12 train_fraction: 0.8 pin_memory: true transform: - _target_: text_recognizer.data.stems.paragraph.ParagraphStem + _target_: text_recognizer.data.transforms.paragraph.ParagraphStem augment: true test_transform: - _target_: text_recognizer.data.stems.paragraph.ParagraphStem + _target_: text_recognizer.data.transforms.paragraph.ParagraphStem augment: false target_transform: _target_: text_recognizer.data.transforms.pad.Pad diff --git a/training/conf/datamodule/iam_lines.yaml b/training/conf/datamodule/iam_lines.yaml index f84116d..1205c75 100644 --- a/training/conf/datamodule/iam_lines.yaml +++ b/training/conf/datamodule/iam_lines.yaml @@ -4,10 +4,10 @@ num_workers: 12 train_fraction: 0.9 pin_memory: true transform: - _target_: text_recognizer.data.stems.line.IamLinesStem + _target_: text_recognizer.data.transforms.line.IamLinesStem augment: true test_transform: - _target_: text_recognizer.data.stems.line.IamLinesStem + _target_: text_recognizer.data.transforms.line.IamLinesStem augment: false tokenizer: _target_: text_recognizer.data.tokenizer.Tokenizer diff --git a/training/conf/experiment/conv_transformer_lines.yaml b/training/conf/experiment/conv_transformer_lines.yaml index 3f5da86..948968a 100644 --- a/training/conf/experiment/conv_transformer_lines.yaml +++ b/training/conf/experiment/conv_transformer_lines.yaml @@ -54,9 +54,6 @@ lr_scheduler: datamodule: batch_size: 16 train_fraction: 0.95 - transform: - _target_: text_recognizer.data.stems.line.IamLinesStem - augment: false network: _target_: text_recognizer.networks.ConvTransformer -- cgit v1.2.3-70-g09d2