summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_lines.py
blob: 391075ab2672a35d744f3862dc26d53d486a9332 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""Class for IAM Lines dataset.

If not already created, generates a handwritten lines dataset by cropping the
line regions out of the IAM form images.

"""
import json
from pathlib import Path
import random
from typing import List, Sequence, Tuple

from loguru import logger
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

from text_recognizer.data.base_dataset import (
    BaseDataset,
    convert_strings_to_labels,
    split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils


# Global PIL setting: allow loading truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Seed used for the train/val split in ``IAMLines.setup``.
SEED = 4711
# Cache directory for the cropped line images, labels and aspect-ratio statistic.
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
# Every line crop is resized to this height (pixels)...
IMAGE_HEIGHT = 56
# ...and embedded in a canvas of this width; also the upper bound checked in setup.
IMAGE_WIDTH = 1024


class IAMLines(BaseDataModule):
    """IAM handwritten lines dataset.

    Line images are cropped out of the IAM form images in ``prepare_data``,
    cached as PNG files plus a JSON label file under ``PROCESSED_DATA_DIRNAME``,
    and loaded into ``BaseDataset`` objects in ``setup``.
    """

    def __init__(
        self,
        augment: bool = True,
        fraction: float = 0.8,
        batch_size: int = 128,
        num_workers: int = 0,
    ) -> None:
        """Configure the data module.

        Args:
            augment: Whether the training dataset uses the augmenting transform.
                The validation split is carved out of that same dataset, so it
                shares the transform.
            fraction: Passed to ``split_dataset`` as ``fraction``. Presumably this
                is the share of the training data kept for training, with the
                rest used for validation (confirm in ``split_dataset``).
            batch_size: Forwarded to ``BaseDataModule``.
            num_workers: Forwarded to ``BaseDataModule``.
        """
        # TODO: add transforms
        super().__init__(batch_size, num_workers)
        self.augment = augment
        self.fraction = fraction
        self.mapping, self.inverse_mapping, _ = emnist_mapping()
        self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
        # First entry must equal the longest label length + 2 (checked in setup).
        self.output_dims = (89, 1)
        self.data_train: BaseDataset = None
        self.data_val: BaseDataset = None
        self.data_test: BaseDataset = None

    def prepare_data(self) -> None:
        """Create the processed IAM lines dataset if it does not exist yet.

        ``_max_aspect_ratio.txt`` is written after all images and labels. Its
        presence, rather than the mere existence of the directory, therefore
        marks a completed run, and an interrupted run is redone.
        """
        if (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").exists():
            return

        logger.info("Cropping IAM lines regions...")
        iam = IAM()
        iam.prepare_data()
        crops_train, labels_train = line_crops_and_labels(iam, "train")
        crops_test, labels_test = line_crops_and_labels(iam, "test")

        # PIL sizes are (width, height), so this is width / height per crop.
        shapes = np.array([crop.size for crop in crops_train + crops_test])
        aspect_ratios = shapes[:, 0] / shapes[:, 1]

        logger.info("Saving images, labels, and statistics...")
        save_images_and_labels(
            crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
        )
        save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)

        # Written last: acts as the "processing finished" marker.
        with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f:
            f.write(str(aspect_ratios.max()))

    def setup(self, stage: str = None) -> None:
        """Load data for training/testing.

        Args:
            stage: ``"fit"`` loads the train/val splits and ``"test"`` loads the
                test split. ``None`` loads both and also verifies
                ``output_dims`` against the labels.

        Raises:
            ValueError: If the widest line would not fit into ``IMAGE_WIDTH``,
                if any label length + 2 exceeds ``output_dims[0]``, or (stage
                ``None``) if ``output_dims`` differs from the longest label + 2.
        """
        with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f:
            max_aspect_ratio = float(f.read())
        image_width = int(IMAGE_HEIGHT * max_aspect_ratio)
        if image_width >= IMAGE_WIDTH:
            raise ValueError("image_width equal or greater than IMAGE_WIDTH")

        if stage == "fit" or stage is None:
            x_train, labels_train = load_line_crops_and_labels(
                "train", PROCESSED_DATA_DIRNAME
            )
            # +2 leaves room for two extra tokens per target (presumably
            # start/end tokens; confirm in convert_strings_to_labels).
            if self.output_dims[0] < max(len(label) for label in labels_train) + 2:
                raise ValueError("Target length longer than max output length.")

            y_train = convert_strings_to_labels(
                labels_train, self.inverse_mapping, length=self.output_dims[0]
            )
            data_train = BaseDataset(
                x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
            )

            self.data_train, self.data_val = split_dataset(
                dataset=data_train, fraction=self.fraction, seed=SEED
            )

        if stage == "test" or stage is None:
            x_test, labels_test = load_line_crops_and_labels(
                "test", PROCESSED_DATA_DIRNAME
            )

            if self.output_dims[0] < max(len(label) for label in labels_test) + 2:
                raise ValueError("Target length longer than max output length.")

            y_test = convert_strings_to_labels(
                labels_test, self.inverse_mapping, length=self.output_dims[0]
            )
            self.data_test = BaseDataset(
                x_test, y_test, transform=get_transform(IMAGE_WIDTH)
            )

        if stage is None:
            self._verify_output_dims(labels_train, labels_test)

    def _verify_output_dims(
        self, labels_train: List[str], labels_test: List[str]
    ) -> None:
        """Require ``output_dims`` to be exactly (longest label + 2, 1).

        Raises:
            ValueError: If the hard-coded ``output_dims`` does not match the data.
        """
        max_label_length = max(len(label) for label in labels_train + labels_test) + 2
        output_dims = (max_label_length, 1)
        if output_dims != self.output_dims:
            raise ValueError("Output dim does not match expected output dims.")

    def __repr__(self) -> str:
        """Return information about the dataset.

        Batch statistics are included only for splits that have been set up.
        It is therefore safe to call this after ``setup("fit")`` or
        ``setup("test")`` alone.
        """
        basic = (
            "IAM Lines dataset\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Input dims: {self.dims}\n"
            f"Output dims: {self.output_dims}\n"
        )

        splits = (self.data_train, self.data_val, self.data_test)
        # Compare with None explicitly: the splits support len(), so an empty
        # (but set up) split would be falsy under a plain truthiness test.
        if all(split is None for split in splits):
            return basic

        sizes = ", ".join(
            "None" if split is None else str(len(split)) for split in splits
        )
        data = f"Train/val/test sizes: {sizes}\n"
        if self.data_train is not None:
            x, y = next(iter(self.train_dataloader()))
            data += (
                f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
                f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
            )
        if self.data_test is not None:
            xt, yt = next(iter(self.test_dataloader()))
            data += (
                f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
                f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
            )
        return basic + data


def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]:
    """Crop every text line out of the IAM forms that belong to ``split``.

    Each form image is read, converted to grayscale and inverted before its
    line regions are cropped.

    Args:
        iam: Prepared IAM dataset providing form files, split ids, line strings
            and line regions.
        split: Split name to keep, matched against ``iam.split_by_id``.

    Returns:
        ``(crops, labels)``: equally long lists of PIL line crops and line strings.

    Raises:
        ValueError: If the total number of crops differs from the number of labels.
    """
    crops: List = []
    labels: List = []
    corner_keys = ["x1", "y1", "x2", "y2"]
    for form_path in iam.form_filenames:
        form_id = form_path.stem
        if iam.split_by_id[form_id] != split:
            continue
        form_image = ImageOps.invert(
            ImageOps.grayscale(image_utils.read_image_pil(form_path))
        )
        labels.extend(iam.line_strings_by_id[form_id])
        crops.extend(
            form_image.crop([region[key] for key in corner_keys])
            for region in iam.line_regions_by_id[form_id]
        )
    if len(crops) != len(labels):
        raise ValueError("Length of crops does not match length of labels")
    return crops, labels


def save_images_and_labels(
    crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path
) -> None:
    """Cache one split on disk under ``data_dirname / split``.

    Labels go to ``_labels.json`` as a JSON list. Crop ``i`` is saved as
    ``<i>.png``, so the file index matches the label position.
    """
    split_dir = data_dirname / split
    split_dir.mkdir(parents=True, exist_ok=True)

    labels_path = split_dir / "_labels.json"
    with labels_path.open(mode="w") as labels_file:
        json.dump(labels, labels_file)

    for idx, image in enumerate(crops):
        image.save(split_dir / f"{idx}.png")


def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, List]:
    """Load the cached line crops and labels of ``split`` from the processed directory.

    ``<index>.png`` files are sorted by their integer index, so the crops line
    up with the entries of ``_labels.json``. Each crop is read in grayscale.

    Raises:
        ValueError: If the number of crops differs from the number of labels.
    """
    split_dir = data_dirname / split
    labels_path = split_dir / "_labels.json"
    with labels_path.open(mode="r") as labels_file:
        labels = json.load(labels_file)

    def _index_of(png_path: Path) -> int:
        # "12.png" -> 12; numeric rather than lexicographic order.
        return int(Path(png_path).stem)

    png_paths = sorted(split_dir.glob("*.png"), key=_index_of)
    crops = [image_utils.read_image_pil(path, grayscale=True) for path in png_paths]

    if len(labels) != len(crops):
        raise ValueError("Length of crops does not match length of labels")

    return crops, labels


def get_transform(image_width: int, augment: bool = False) -> transforms.Compose:
    """Build the transform pipeline for a grayscale line crop.

    Every crop is resized to height ``IMAGE_HEIGHT``, keeping its aspect ratio.
    It is then pasted onto a black ``(image_width, IMAGE_HEIGHT)`` canvas and
    converted to a tensor. The crop width is clamped so it never exceeds the
    canvas.

    With ``augment=True`` three extra steps are applied:
    - the crop is horizontally stretched by a random factor in [0.9, 1.1];
    - random brightness jitter (factor in [0.8, 1.6]) is applied after embedding;
    - a random affine with at most 1 degree of rotation and an x-shear in
      (-30, 20) degrees (slant) is applied after embedding.

    Args:
        image_width: Width of the output canvas in pixels.
        augment: Whether to add the random augmentations.

    Returns:
        A ``transforms.Compose`` mapping a PIL crop to a tensor.
    """

    def embed_crop(
        crop: Image.Image, augment: bool = augment, image_width: int = image_width
    ) -> Image.Image:
        """Resize ``crop`` to ``IMAGE_HEIGHT`` and paste it onto a fixed-size canvas."""
        # Crop is a PIL image of mode "L" (so value range is [0, 255]).
        image = Image.new("L", (image_width, IMAGE_HEIGHT))

        # Resize crop to the canvas height, preserving its aspect ratio.
        crop_width, crop_height = crop.size
        new_crop_height = IMAGE_HEIGHT
        new_crop_width = int(new_crop_height * (crop_width / crop_height))

        if augment:
            # Add random horizontal stretching.
            new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
        # Always clamp, not only when augmenting. A crop wider than the canvas
        # would be pasted at a negative x offset and lose the start of the line.
        new_crop_width = min(new_crop_width, image_width)
        crop_resized = crop.resize(
            (new_crop_width, new_crop_height), resample=Image.BILINEAR
        )

        # Embed in the canvas with up to 28 px of left padding.
        x = min(28, image_width - new_crop_width)
        y = IMAGE_HEIGHT - new_crop_height
        image.paste(crop_resized, (x, y))

        return image

    transforms_list = [transforms.Lambda(embed_crop)]

    if augment:
        transforms_list += [
            transforms.ColorJitter(brightness=(0.8, 1.6)),
            transforms.RandomAffine(
                degrees=1,
                shear=(-30, 20),
                interpolation=InterpolationMode.BILINEAR,
                fill=0,
            ),
        ]
    transforms_list.append(transforms.ToTensor())
    return transforms.Compose(transforms_list)


def generate_iam_lines() -> None:
    """Run ``load_and_print_info`` on the ``IAMLines`` data module class.

    NOTE(review): presumably this prepares/sets up the dataset and prints its
    summary. Confirm in ``base_data_module.load_and_print_info``.
    """
    data_module_class = IAMLines
    load_and_print_info(data_module_class)