summaryrefslogtreecommitdiff
path: root/text_recognizer/data/stems/paragraph.py
blob: 39e1e5919abff057899f2e6cfcbef3d5408290cc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""Iam paragraph stem class."""
import torchvision.transforms as T

import text_recognizer.metadata.iam_paragraphs as metadata
from text_recognizer.data.stems.image import ImageStem


# Module-level aliases for values defined in the IAM paragraphs metadata
# module (imported above as ``metadata``).
IMAGE_HEIGHT = metadata.IMAGE_HEIGHT
IMAGE_WIDTH = metadata.IMAGE_WIDTH
IMAGE_SHAPE = metadata.IMAGE_SHAPE
MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH


class ParagraphStem(ImageStem):
    """Image stem whose PIL transform targets paragraph-sized images.

    Without augmentation, ``pil_transform`` is a single center crop to
    ``IMAGE_SHAPE``. With augmentation it chains, in this order: color jitter,
    a random crop to ``IMAGE_SHAPE`` (``pad_if_needed=True`` with a constant
    fill of 0), random affine, random perspective, gaussian blur and random
    sharpness adjustment.

    Args:
        augment: When truthy, build the augmenting pipeline; otherwise only
            the center crop is used and every ``*_kwargs`` argument is ignored.
        color_jitter_kwargs: Keyword arguments forwarded to ``T.ColorJitter``.
        random_affine_kwargs: Keyword arguments forwarded to
            ``T.RandomAffine``.
        random_perspective_kwargs: Keyword arguments forwarded to
            ``T.RandomPerspective``.
        gaussian_blur_kwargs: Keyword arguments forwarded to
            ``T.GaussianBlur``.
        sharpness_kwargs: Keyword arguments forwarded to
            ``T.RandomAdjustSharpness``.

    Any ``*_kwargs`` argument left as ``None`` is replaced by the default
    dictionary written out in ``__init__``.
    """

    def __init__(
        self,
        augment=False,
        color_jitter_kwargs=None,
        random_affine_kwargs=None,
        random_perspective_kwargs=None,
        gaussian_blur_kwargs=None,
        sharpness_kwargs=None,
    ):
        super().__init__()

        # Plain (non-augmenting) mode: deterministic crop and nothing else.
        if not augment:
            self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)])
            return

        # Resolve each option set; only ``None`` triggers the defaults, so an
        # explicitly passed (even empty) dict is used as given.
        jitter_cfg = (
            {"brightness": 0.4, "contrast": 0.4}
            if color_jitter_kwargs is None
            else color_jitter_kwargs
        )
        affine_cfg = (
            {
                "degrees": 3,
                "shear": 6,
                "scale": (0.95, 1),
                "interpolation": T.InterpolationMode.BILINEAR,
            }
            if random_affine_kwargs is None
            else random_affine_kwargs
        )
        perspective_cfg = (
            {
                "distortion_scale": 0.2,
                "p": 0.5,
                "interpolation": T.InterpolationMode.BILINEAR,
            }
            if random_perspective_kwargs is None
            else random_perspective_kwargs
        )
        blur_cfg = (
            {"kernel_size": (3, 3), "sigma": (0.1, 1.0)}
            if gaussian_blur_kwargs is None
            else gaussian_blur_kwargs
        )
        sharpen_cfg = (
            {"sharpness_factor": 2, "p": 0.5}
            if sharpness_kwargs is None
            else sharpness_kwargs
        )

        # The order of the steps below is the order in which they are applied.
        augmentation_steps = [
            T.ColorJitter(**jitter_cfg),
            T.RandomCrop(
                size=IMAGE_SHAPE,
                padding=None,
                pad_if_needed=True,
                fill=0,
                padding_mode="constant",
            ),
            T.RandomAffine(**affine_cfg),
            T.RandomPerspective(**perspective_cfg),
            T.GaussianBlur(**blur_cfg),
            T.RandomAdjustSharpness(**sharpen_cfg),
        ]
        self.pil_transform = T.Compose(augmentation_steps)