summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam_paragraphs_dataset.py
blob: 8ba5142e0ee861022dff35e248a4937087f2436f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
"""IamParagraphsDataset class and functions for data processing."""
from pathlib import Path
import random
from typing import Callable, Dict, List, Optional, Tuple, Union

import click
import cv2
import h5py
from loguru import logger
import numpy as np
import torch
from torch import Tensor
from torchvision.transforms import ToTensor

from text_recognizer import util
from text_recognizer.datasets.dataset import Dataset
from text_recognizer.datasets.iam_dataset import IamDataset
from text_recognizer.datasets.util import (
    compute_sha256,
    DATA_DIRNAME,
    download_url,
    EmnistMapper,
)

# Intermediate artifacts; in this module only the debug crops (line boxes drawn on the
# paragraph crop) are written under the interim directory.
INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
# Processed outputs that _load_iam_paragraphs reads back: paragraph crops (*.jpg in
# CROPS_DIRNAME) and their per-pixel label maps (*.png in GT_DIRNAME), one per form id.
PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"
CROPS_DIRNAME = PROCESSED_DATA_DIRNAME / "crops"
GT_DIRNAME = PROCESSED_DATA_DIRNAME / "gt"

PARAGRAPH_BUFFER = 50  # Pixels of the form image kept above the first and below the last line.
TEST_FRACTION = 0.2  # Fraction of the samples held out for the test split.
# Seeds the train/test permutation; also the exclusive upper bound for the per-sample
# seed drawn in IamParagraphsDataset.__getitem__.
SEED = 4711


class IamParagraphsDataset(Dataset):
    """IAM Paragraphs dataset for paragraphs of handwritten text.

    Samples are grayscale crops of the handwritten paragraph of each IAM form, paired
    with a per-pixel label map: 0=background, 1=odd-numbered line, 2=even-numbered line.
    The data, targets and ids are populated by ``load_or_generate_data``.
    """

    def __init__(
        self,
        train: bool = False,
        subsample_fraction: Optional[float] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        super().__init__(
            train=train,
            subsample_fraction=subsample_fraction,
            transform=transform,
            target_transform=target_transform,
        )
        # Source of the form images and the per-form line regions.
        self.iam_dataset = IamDataset()

        # Background, odd-numbered line and even-numbered line.
        self.num_classes = 3
        self._input_shape = (256, 256)
        self._output_shape = self._input_shape + (self.num_classes,)
        # Form ids of the samples; filled in by load_or_generate_data.
        self._ids = None

    def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
        """Fetches data, target pair of the dataset for a given index or indices.

        Args:
            index (Union[int, Tensor]): Either a list or int of indices/index.

        Returns:
            Tuple[Tensor, Tensor]: Data target pair, with the targets cast to int64.

        """
        if torch.is_tensor(index):
            index = index.tolist()

        data = self.data[index]
        targets = self.targets[index]

        # Draw one seed and re-apply it before each transform, so that any random
        # transform draws the same values for the data and for its targets.
        seed = np.random.randint(SEED)
        random.seed(seed)
        torch.manual_seed(seed)  # needed for torchvision 0.7
        if self.transform:
            data = self.transform(data)

        random.seed(seed)
        torch.manual_seed(seed)  # needed for torchvision 0.7
        if self.target_transform:
            targets = self.target_transform(targets)

        # Without a target_transform the targets may still be a numpy array, which has no
        # .long(); torch.as_tensor is a no-op for tensors.
        return data, torch.as_tensor(targets).long()

    @property
    def ids(self) -> Optional[np.ndarray]:
        """Form ids of the samples, split alongside ``data`` and ``targets``."""
        return self._ids

    def get_data_and_target_from_id(self, id_: str) -> Tuple[Tensor, Tensor]:
        """Get the data target pair of the form with the given id.

        Args:
            id_ (str): Form id (the file stem of the form image).

        Returns:
            Tuple[Tensor, Tensor]: Data target pair.

        Raises:
            ValueError: If ``id_`` is not part of this dataset.

        """
        # ids is a numpy array (no list.index), so look the id up element-wise.
        matches = np.flatnonzero(np.asarray(self.ids) == id_)
        if matches.size == 0:
            raise ValueError(f"{id_!r} is not in the dataset.")
        ind = int(matches[0])
        return self.data[ind], self.targets[ind]

    def load_or_generate_data(self) -> None:
        """Load or generate dataset data."""
        num_actual = len(list(CROPS_DIRNAME.glob("*.jpg")))
        num_targets = len(self.iam_dataset.line_regions_by_id)

        # Regenerate the crops unless (almost) every form already has one. The slack of 2
        # presumably tolerates forms that _crop_paragraph_image skips -- TODO confirm.
        if num_actual < num_targets - 2:
            self._process_iam_paragraphs()

        self._data, self._targets, self._ids = _load_iam_paragraphs()
        self._get_random_split()
        # NOTE(review): _subsample is defined in the base class (not visible here); confirm
        # whether it should also subsample self._ids.
        self._subsample()

    def _get_random_split(self) -> None:
        """Keep only this dataset's split (train or test) of the loaded samples.

        The permutation is drawn after seeding numpy with SEED, so train and test
        instances built from the same loaded data get complementary splits.
        """
        np.random.seed(SEED)
        num_samples = self.data.shape[0]
        num_train = int((1 - TEST_FRACTION) * num_samples)
        indices = np.random.permutation(num_samples)
        split_indices = indices[:num_train] if self.train else indices[num_train:]

        self._data = self.data[split_indices]
        self._targets = self.targets[split_indices]
        # Keep the ids aligned with the samples they describe.
        if self._ids is not None:
            self._ids = np.asarray(self._ids)[split_indices]

    def _process_iam_paragraphs(self) -> None:
        """Crop the part with the text.

        For each page, crop out the part of it that corresponds to the paragraph of text,
        and make sure all crops are self.input_shape. The ground truth data is the same
        size, with a label at each pixel corresponding to 0=background,
        1=odd-numbered line, 2=even-numbered line.
        """
        crop_dims = self._decide_on_crop_dims()
        CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
        DEBUG_CROPS_DIRNAME.mkdir(parents=True, exist_ok=True)
        GT_DIRNAME.mkdir(parents=True, exist_ok=True)
        logger.info(
            f"Cropping paragraphs, generating ground truth, and saving debugging images to {DEBUG_CROPS_DIRNAME}"
        )
        for filename in self.iam_dataset.form_filenames:
            id_ = filename.stem
            line_region = self.iam_dataset.line_regions_by_id[id_]
            _crop_paragraph_image(filename, line_region, crop_dims, self.input_shape)

    def _decide_on_crop_dims(self) -> Tuple[int, int]:
        """Decide on the dimensions to crop out of the form image.

        Since the form image width is larger than a comfortable crop around the tallest
        paragraph, the crop is made square with a side equal to the form width. The
        crops are later resized to the dataset's input shape.

        Returns:
            Tuple[int, int]: A tuple of crop dimensions.

        Raises:
            RuntimeError: When max crop height is larger than max crop width.

        """
        sample_form_filename = self.iam_dataset.form_filenames[0]
        sample_image = util.read_image(sample_form_filename, grayscale=True)
        # NOTE(review): assumes every form is as wide as the first one; forms of a
        # different width cannot be pasted into the canvas and get skipped when cropping.
        max_crop_width = sample_image.shape[1]
        max_crop_height = _get_max_paragraph_crop_height(
            self.iam_dataset.line_regions_by_id
        )
        if max_crop_height > max_crop_width:
            raise RuntimeError(
                f"Max crop height is larger than max crop width: {max_crop_height} > {max_crop_width}"
            )

        crop_dims = (max_crop_width, max_crop_width)
        logger.info(
            f"Max crop width and height were found to be {max_crop_width}x{max_crop_height}."
        )
        logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
        return crop_dims

    def __repr__(self) -> str:
        """Return info about the dataset."""
        return (
            "IAM Paragraph Dataset\n"  # pylint: disable=no-member
            f"Num classes: {self.num_classes}\n"
            f"Data: {self.data.shape}\n"
            f"Targets: {self.targets.shape}\n"
        )


def _get_max_paragraph_crop_height(line_regions_by_id: Dict) -> int:
    heights = []
    for regions in line_regions_by_id.values():
        min_y1 = min(region["y1"] for region in regions) - PARAGRAPH_BUFFER
        max_y2 = max(region["y2"] for region in regions) + PARAGRAPH_BUFFER
        height = max_y2 - min_y1
        heights.append(height)
    return max(heights)


def _crop_paragraph_image(
    filename: Union[str, Path],
    line_regions: List[Dict],
    crop_dims: Tuple[int, int],
    final_dims: Tuple,
) -> None:
    """Crop the paragraph of one form and write the crop, its label map and a debug image.

    The form rows spanning all line regions (plus PARAGRAPH_BUFFER pixels above and below)
    are pasted, vertically centered, into a white canvas of shape ``crop_dims``. The label
    map marks each line's bounding box with 1 (1st, 3rd, ... line) or 2 (2nd, 4th, ...
    line) on a background of 0. All three images are resized to ``final_dims`` and written
    to CROPS_DIRNAME, GT_DIRNAME and DEBUG_CROPS_DIRNAME under the form's file stem.

    Forms without line regions, or whose paragraph cannot be pasted into the canvas, are
    logged and skipped so that one bad form does not abort the whole run.

    Args:
        filename: Path to the form image; its stem names the output files.
        line_regions: Line bounding boxes with keys "x1", "y1", "x2", "y2" in form pixels.
        crop_dims: Shape (rows, cols) of the canvas before resizing.
        final_dims: Output size passed to cv2.resize (which expects (width, height)).
    """
    # Accept plain strings as well; .stem below needs a Path.
    filename = Path(filename)
    if not line_regions:
        logger.error(f"Rescued {filename}: no line regions")
        return

    image = util.read_image(filename, grayscale=True)

    # Vertical extent of the paragraph, padded on both sides.
    min_y1 = min(region["y1"] for region in line_regions) - PARAGRAPH_BUFFER
    max_y2 = max(region["y2"] for region in line_regions) + PARAGRAPH_BUFFER
    height = max_y2 - min_y1
    crop_height = crop_dims[0]
    # Top offset that centers the paragraph vertically in the canvas.
    buffer = (crop_height - height) // 2

    # Generate image crop on a white (255) canvas.
    image_crop = 255 * np.ones(crop_dims, dtype=np.uint8)
    try:
        image_crop[buffer : buffer + height] = image[min_y1:max_y2]
    except Exception as e:  # pylint: disable=broad-except
        # Best effort: e.g. a form whose width differs from the canvas cannot be pasted.
        logger.error(f"Rescued {filename}: {e}")
        return

    # Generate ground truth: 0 = background, 1/2 alternate between consecutive lines.
    gt_image = np.zeros_like(image_crop, dtype=np.uint8)
    for index, region in enumerate(line_regions):
        gt_image[
            (region["y1"] - min_y1 + buffer) : (region["y2"] - min_y1 + buffer),
            region["x1"] : region["x2"],
        ] = (index % 2 + 1)

    # Generate image for debugging: each line box drawn in its own color on an RGB copy.
    # matplotlib is imported here since it is only needed for these debug colors.
    import matplotlib.pyplot as plt

    cmap = plt.get_cmap("Set1")
    image_crop_for_debug = np.dstack([image_crop, image_crop, image_crop])
    for index, region in enumerate(line_regions):
        # The colormap yields RGBA floats in [0, 1]; drop alpha and scale to 0-255.
        color = [255 * channel for channel in cmap(index)[:-1]]
        cv2.rectangle(
            image_crop_for_debug,
            (region["x1"], region["y1"] - min_y1 + buffer),
            (region["x2"], region["y2"] - min_y1 + buffer),
            color,
            3,
        )
    image_crop_for_debug = cv2.resize(
        image_crop_for_debug, final_dims, interpolation=cv2.INTER_AREA
    )
    util.write_image(image_crop_for_debug, DEBUG_CROPS_DIRNAME / f"{filename.stem}.jpg")

    image_crop = cv2.resize(image_crop, final_dims, interpolation=cv2.INTER_AREA)
    util.write_image(image_crop, CROPS_DIRNAME / f"{filename.stem}.jpg")

    # Nearest-neighbour keeps the label map free of interpolated (non-class) values.
    gt_image = cv2.resize(gt_image, final_dims, interpolation=cv2.INTER_NEAREST)
    util.write_image(gt_image, GT_DIRNAME / f"{filename.stem}.png")


def _load_iam_paragraphs() -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load the processed paragraph crops, their label maps and form ids from disk.

    Files are read in sorted order, so sample indices (and therefore the seeded
    train/test split) do not depend on the filesystem's directory listing order.

    Returns:
        Tuple of (images, gt_images, ids): the crops as float32 computed as
        1.0 - pixel / 255, the label maps as uint8, and the form ids (file stems).
    """
    logger.info("Loading IAM paragraph crops and ground truth from image files...")
    images = []
    gt_images = []
    ids = []
    for filename in sorted(CROPS_DIRNAME.glob("*.jpg")):
        id_ = filename.stem
        image = util.read_image(filename, grayscale=True)
        # Invert intensities so the white (255) background maps to 0.
        image = 1.0 - image / 255

        gt_filename = GT_DIRNAME / f"{id_}.png"
        gt_image = util.read_image(gt_filename, grayscale=True)

        images.append(image)
        gt_images.append(gt_image)
        ids.append(id_)
    images = np.array(images).astype(np.float32)
    gt_images = np.array(gt_images).astype(np.uint8)
    ids = np.array(ids)
    return images, gt_images, ids


@click.command()
@click.option(
    "--subsample_fraction",
    type=float,
    default=None,
    help="The subsampling factor of the dataset.",
)
def main(subsample_fraction: float) -> None:
    """Load dataset and print info."""
    # The docstring above doubles as the click --help text, so it is kept verbatim.
    # Build the train split first, then the test split, printing a summary of each.
    for split_name, is_train in (("train", True), ("test", False)):
        logger.info(f"Creating {split_name} set...")
        split = IamParagraphsDataset(
            train=is_train, subsample_fraction=subsample_fraction
        )
        split.load_or_generate_data()
        print(split)


if __name__ == "__main__":
    main()