summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/line.py
blob: e4473eb3ec87818fb4e9ea34be56e05ed189a2b4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import random
from typing import Any, Dict

import torchvision.transforms as T
from PIL import Image

import text_recognizer.metadata.iam_lines as metadata
from text_recognizer.data.transforms.image import ImageStem


class LineStem(ImageStem):
    """A stem for handling images containing a line of text.

    When ``augment`` is true, a PIL-level pipeline of ``T.ColorJitter``
    followed by ``T.RandomAffine`` is installed on the stem. When it is
    false, nothing beyond ``ImageStem.__init__`` is configured.

    Args:
        augment: Whether to install the random augmentation pipeline.
        color_jitter_kwargs: Keyword arguments forwarded to
            ``T.ColorJitter``. Defaults to ``{"brightness": (0.5, 1)}``.
        random_affine_kwargs: Keyword arguments forwarded to
            ``T.RandomAffine``. Defaults to a small rotation, a slight
            translation, a (0.4, 1.1) scale range, a (-40, 50) shear
            range, bilinear interpolation and a fill value of 0.
    """

    def __init__(
        self,
        augment: bool = False,
        color_jitter_kwargs: Dict[str, Any] = None,
        random_affine_kwargs: Dict[str, Any] = None,
    ) -> None:
        super().__init__()
        if color_jitter_kwargs is None:
            color_jitter_kwargs = {"brightness": (0.5, 1)}
        if random_affine_kwargs is None:
            random_affine_kwargs = {
                "degrees": 3,
                "translate": (0, 0.05),
                "scale": (0.4, 1.1),
                "shear": (-40, 50),
                "interpolation": T.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        if augment:
            augmentation = T.Compose(
                [
                    T.ColorJitter(**color_jitter_kwargs),
                    T.RandomAffine(**random_affine_kwargs),
                ]
            )
            # IamLinesStem in this module stores its pipeline on
            # `pil_transform` (singular), whereas this class historically
            # used `pil_transforms` (plural). Publish the pipeline under
            # both names so it is picked up whichever one ImageStem applies.
            # NOTE(review): ImageStem is not visible here - confirm which
            # attribute it uses and drop the other.
            self.pil_transform = augmentation
            self.pil_transforms = augmentation


class IamLinesStem(ImageStem):
    """Stem for line crops from the IAMLines dataset.

    Every crop is first rescaled to the configured image height and pasted
    onto a blank fixed-size greyscale canvas. When ``augment`` is true the
    crop width is additionally stretched at random and the canvas is passed
    through ``T.ColorJitter`` and ``T.RandomAffine``.

    Args:
        augment: Whether to apply the random stretch and augmentations.
        color_jitter_kwargs: Keyword arguments forwarded to
            ``T.ColorJitter``. Defaults to ``{"brightness": (0.8, 1.6)}``.
        random_affine_kwargs: Keyword arguments forwarded to
            ``T.RandomAffine``. Defaults to 1 degree of rotation, a
            (-30, 20) shear range, bilinear interpolation and fill 0.
    """

    def __init__(
        self,
        augment: bool = False,
        color_jitter_kwargs: Dict[str, Any] = None,
        random_affine_kwargs: Dict[str, Any] = None,
    ) -> None:
        super().__init__()

        def embed_crop(crop, augment=augment):
            """Return a new "L" canvas with the height-normalised crop pasted in."""
            # A fresh mode "L" image is the blank background for the crop.
            canvas = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT))

            # Scale to the full canvas height, keeping the aspect ratio.
            source_width, source_height = crop.size
            target_height = metadata.IMAGE_HEIGHT
            target_width = int(target_height * (source_width / source_height))
            if augment:
                # Random horizontal stretch of +/-10 %, capped at the canvas
                # width.
                # NOTE(review): the cap is only applied when augmenting; a
                # crop wider than the canvas otherwise yields a negative
                # paste offset below - confirm inputs always fit.
                stretched_width = int(target_width * random.uniform(0.9, 1.1))
                target_width = min(stretched_width, metadata.IMAGE_WIDTH)
            scaled_crop = crop.resize(
                (target_width, target_height), resample=Image.BILINEAR
            )

            # Indent by at most CHAR_WIDTH; align the crop with the bottom.
            left = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - target_width)
            top = metadata.IMAGE_HEIGHT - target_height
            canvas.paste(scaled_crop, (left, top))
            return canvas

        if color_jitter_kwargs is None:
            color_jitter_kwargs = {"brightness": (0.8, 1.6)}
        if random_affine_kwargs is None:
            random_affine_kwargs = {
                "degrees": 1,
                "shear": (-30, 20),
                "interpolation": T.InterpolationMode.BILINEAR,
                "fill": 0,
            }

        # Embedding always runs; the random augmentations are appended on top.
        stages = [T.Lambda(embed_crop)]
        if augment:
            stages.append(T.ColorJitter(**color_jitter_kwargs))
            stages.append(T.RandomAffine(**random_affine_kwargs))
        self.pil_transform = T.Compose(stages)