path: root/text_recognizer/data/iam_lines.py
"""Class for IAM Lines dataset.

If not created, will generate a handwritten lines dataset from the IAM paragraphs
dataset.
"""
import json
from pathlib import Path
from typing import List, Optional, Sequence, Tuple

from attrs import define, field
from loguru import logger as log
import numpy as np
from PIL import Image, ImageFile, ImageOps

from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.base_dataset import (
    BaseDataset,
    convert_strings_to_labels,
    split_dataset,
)
from text_recognizer.data.iam import IAM
from text_recognizer.data.mappings.emnist import EmnistMapping
from text_recognizer.data.transforms.load_transform import load_transform_from_file
from text_recognizer.data.utils import image_utils


# Allow PIL to load truncated image files instead of raising an error; some of
# the raw IAM scans appear to be missing trailing bytes.
ImageFile.LOAD_TRUNCATED_IMAGES = True

SEED = 4711  # Seed for the deterministic train/val split.
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
IMAGE_HEIGHT = 56  # Target height for line crops.
IMAGE_WIDTH = 1024  # Fixed input width; must fit the widest crop at IMAGE_HEIGHT.
MAX_LABEL_LENGTH = 89  # Padded target length, including start and end tokens.
MAX_WORD_PIECE_LENGTH = 72  # Unused here; presumably imported by word-piece datasets.


@define(auto_attribs=True, repr=False)
class IAMLines(BaseDataModule):
    """IAM handwritten lines dataset."""

    # Input images are (channels, height, width); targets are (length, 1).
    dims: Tuple[int, int, int] = field(
        init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
    )
    output_dims: Tuple[int, int] = field(init=False, default=(MAX_LABEL_LENGTH, 1))

    def prepare_data(self) -> None:
        """Creates the IAM lines dataset if not existing."""
        if PROCESSED_DATA_DIRNAME.exists():
            return

        log.info("Cropping IAM lines regions...")
        iam = IAM(mapping=EmnistMapping())
        iam.prepare_data()
        crops_train, labels_train = line_crops_and_labels(iam, "train")
        crops_test, labels_test = line_crops_and_labels(iam, "test")

        # PIL's Image.size is (width, height), so each ratio is width / height.
        shapes = np.array([crop.size for crop in crops_train + crops_test])
        aspect_ratios = shapes[:, 0] / shapes[:, 1]

        log.info("Saving images, labels, and statistics...")
        save_images_and_labels(
            crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
        )
        save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)

        # Persist the max aspect ratio so setup() can validate IMAGE_WIDTH
        # without reloading every crop.
        with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="w") as f:
            f.write(str(aspect_ratios.max()))

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data for training/testing."""
        with (PROCESSED_DATA_DIRNAME / "_max_aspect_ratio.txt").open(mode="r") as f:
            max_aspect_ratio = float(f.read())
            image_width = int(IMAGE_HEIGHT * max_aspect_ratio)
            if image_width >= IMAGE_WIDTH:
                raise ValueError("image_width equal or greater than IMAGE_WIDTH")

        if stage == "fit" or stage is None:
            x_train, labels_train = load_line_crops_and_labels(
                "train", PROCESSED_DATA_DIRNAME
            )
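            # +2 leaves room for the start and end tokens that
            # convert_strings_to_labels is expected to wrap around each target.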
            if self.output_dims[0] < max(len(label) for label in labels_train) + 2:
                raise ValueError("Target length longer than max output length.")

            y_train = convert_strings_to_labels(
                labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
            )
            data_train = BaseDataset(
                x_train,
                y_train,
                transform=self.transform,
                target_transform=self.target_transform,
            )

            self.data_train, self.data_val = split_dataset(
                dataset=data_train, fraction=self.train_fraction, seed=SEED
            )

        if stage == "test" or stage is None:
            x_test, labels_test = load_line_crops_and_labels(
                "test", PROCESSED_DATA_DIRNAME
            )

            if self.output_dims[0] < max(len(label) for label in labels_test) + 2:
                raise ValueError("Target length longer than max output length.")

            y_test = convert_strings_to_labels(
                labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
            )
            self.data_test = BaseDataset(
                x_test,
                y_test,
                transform=self.test_transform,
                target_transform=self.target_transform,
            )

        if stage is None:
            self._verify_output_dims(labels_train, labels_test)

    def _verify_output_dims(
        self, labels_train: Sequence[str], labels_test: Sequence[str]
    ) -> None:
        """Checks that output_dims matches the longest label in both splits."""
        # +2 again accounts for the start and end tokens around each target.
        max_label_length = max(len(label) for label in labels_train + labels_test) + 2
        output_dims = (max_label_length, 1)
        if output_dims != self.output_dims:
            raise ValueError("Output dim does not match expected output dims.")

    def __repr__(self) -> str:
        """Return information about the dataset."""
        basic = (
            "IAM Lines dataset\n"
            f"Num classes: {len(self.mapping)}\n"
            f"Input dims: {self.dims}\n"
            f"Output dims: {self.output_dims}\n"
        )

        if not any([self.data_train, self.data_val, self.data_test]):
            return basic

        x, y = next(iter(self.train_dataloader()))
        xt, yt = next(iter(self.test_dataloader()))
        x = x[0] if isinstance(x, list) else x
        xt = xt[0] if isinstance(xt, list) else xt
        data = (
            "Train/val/test sizes: "
            f"{len(self.data_train)}, "
            f"{len(self.data_val)}, "
            f"{len(self.data_test)}\n"
            "Train Batch x stats: "
            f"{(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
            "Train Batch y stats: "
            f"{(y.shape, y.dtype, y.min(), y.max())}\n"
            "Test Batch x stats: "
            f"{(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
            "Test Batch y stats: "
            f"{(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
        )
        return basic + data


def line_crops_and_labels(iam: IAM, split: str) -> Tuple[List, List]:
    """Load IAM line labels and regions, and load image crops."""
    crops = []
    labels = []
    for filename in iam.form_filenames:
        if iam.split_by_id[filename.stem] != split:
            continue
        image = image_utils.read_image_pil(filename)
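        # Grayscale and invert so the ink is high-intensity on a dark background.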
        image = ImageOps.grayscale(image)
        image = ImageOps.invert(image)
        labels += iam.line_strings_by_id[filename.stem]
        crops += [
            image.crop([region[box] for box in ["x1", "y1", "x2", "y2"]])
            for region in iam.line_regions_by_id[filename.stem]
        ]
    if len(crops) != len(labels):
        raise ValueError("Length of crops does not match length of labels")
    return crops, labels


def save_images_and_labels(
    crops: Sequence[Image.Image], labels: Sequence[str], split: str, data_dirname: Path
) -> None:
    """Saves generated images and labels to disk."""
    (data_dirname / split).mkdir(parents=True, exist_ok=True)

    with (data_dirname / split / "_labels.json").open(mode="w") as f:
        json.dump(labels, f)

    # Save each crop as "<index>.png"; the numeric names let the loader restore
    # the original order and re-align crops with labels.
    for index, crop in enumerate(crops):
        crop.save(data_dirname / split / f"{index}.png")


def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, List]:
    """Load line crops and labels for given split from processed directoru."""
    with (data_dirname / split / "_labels.json").open(mode="r") as f:
        labels = json.load(f)

    # Sort numerically by filename stem so crops line up with the saved labels.
    crop_filenames = sorted(
        (data_dirname / split).glob("*.png"),
        key=lambda filename: int(Path(filename).stem),
    )
    crops = [
        image_utils.read_image_pil(filename, grayscale=True)
        for filename in crop_filenames
    ]

    if len(crops) != len(labels):
        raise ValueError("Length of crops does not match length of labels")

    return crops, labels


def generate_iam_lines() -> None:
    """Displays Iam Lines dataset statistics."""
    transform = load_transform_from_file("transform/lines.yaml")
    test_transform = load_transform_from_file("test_transform/lines.yaml")
    load_and_print_info(IAMLines(transform=transform, test_transform=test_transform))
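

if __name__ == "__main__":
    # Hypothetical entry point (an addition, not in the original file): running
    # this module directly builds the dataset if needed and prints its info.
    generate_iam_lines()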