summaryrefslogtreecommitdiff
path: root/src/text_recognizer/datasets/iam_dataset.py
blob: 5e4735067d4e1c3c7258c075dd610b73509b58c5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities."""
import os
from typing import Any, Dict, List, Optional
import zipfile

from boltons.cacheutils import cachedproperty
import defusedxml.ElementTree as ET
from loguru import logger
import toml
from torch.utils.data import Dataset

from text_recognizer.datasets import DATA_DIRNAME
from text_recognizer.datasets.util import _download_raw_dataset

# Line-region post-processing: coordinates read from the XML annotations are
# floor-divided by DOWNSAMPLE_FACTOR and then padded by LINE_REGION_PADDING.
DOWNSAMPLE_FACTOR = 2  # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 0  # Add this many pixels around the exact coordinates.

# On-disk layout of the raw IAM data: <DATA_DIRNAME>/raw/iam/...
RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"  # download/extract config loaded by IamDataset
EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"  # unzipped archive root; expected to hold xml/ and forms/


class IamDataset(Dataset):
    """IAM dataset.

    "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
    which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels."
    From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database

    The data split we will use is
    IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines.
        The validation set has been merged into the train set.
        The train set has 7,101 lines from 326 writers.
        The test set has 1,861 lines from 128 writers.
        The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only.

    This class indexes the raw files under ``EXTRACTED_DATASET_DIRNAME`` (``xml/*.xml`` annotations and
    ``forms/*.jpg`` form images) and exposes per-form line texts and line regions keyed by form id
    (the file stem).
    """

    def __init__(self) -> None:
        # Download/extract configuration (e.g. the archive "filename") stored next to the raw data.
        self.metadata = toml.load(METADATA_FILENAME)

    def load_or_generate_data(self) -> None:
        """Download and extract the IAM dataset if no XML annotation files are present."""
        if not self.xml_filenames:
            self._download_iam()

    @property
    def xml_filenames(self) -> List:
        """List of ``*.xml`` annotation files in the extracted ``xml`` directory (re-scanned on each access)."""
        return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))

    @property
    def form_filenames(self) -> List:
        """List of ``*.jpg`` form images in the extracted ``forms`` directory (re-scanned on each access)."""
        return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))

    def _download_iam(self) -> None:
        """Download and extract the raw IAM archive inside ``RAW_DATA_DIRNAME``.

        ``_extract_raw_dataset`` opens the archive by a relative name and extracts into the current working
        directory (``_download_raw_dataset`` presumably writes relative to it as well), so we temporarily
        ``chdir`` into the raw-data directory. The previous working directory is restored even if
        downloading or extraction raises.
        """
        curdir = os.getcwd()
        os.chdir(RAW_DATA_DIRNAME)
        try:
            _download_raw_dataset(self.metadata)
            _extract_raw_dataset(self.metadata)
        finally:
            os.chdir(curdir)

    @property
    def form_filenames_by_id(self) -> Dict:
        """Map each form id (image file stem) to the path of its form image."""
        return {filename.stem: filename for filename in self.form_filenames}

    @cachedproperty
    def line_strings_by_id(self) -> Dict:
        """Map each form id (XML file stem) to the list of its handwritten line texts.

        Computed once per instance and cached; access it after ``load_or_generate_data`` or an empty
        mapping will be cached.
        """
        return {
            filename.stem: _get_line_strings_from_xml_file(filename)
            for filename in self.xml_filenames
        }

    @cachedproperty
    def line_regions_by_id(self) -> Dict:
        """Map each form id (XML file stem) to a list of line-region dicts with keys x1, y1, x2, y2.

        Computed once per instance and cached; access it after ``load_or_generate_data`` or an empty
        mapping will be cached.
        """
        return {
            filename.stem: _get_line_regions_from_xml_file(filename)
            for filename in self.xml_filenames
        }

    def __repr__(self) -> str:
        """Return a short summary with the number of forms (XML files) found on disk."""
        return f"IAM Dataset\nNumber of forms: {len(self.xml_filenames)}\n"


def _extract_raw_dataset(metadata: Dict) -> None:
    """Unpack the downloaded IAM zip archive into the current working directory.

    Args:
        metadata: Parsed ``metadata.toml``; ``metadata["filename"]`` names the archive, resolved relative
            to the current working directory.
    """
    logger.info("Extracting IAM data.")
    archive_name = metadata["filename"]
    # zipfile drops absolute prefixes and ".." components from member names when extracting.
    with zipfile.ZipFile(archive_name, mode="r") as archive:
        archive.extractall()


def _get_line_strings_from_xml_file(filename: str) -> List[str]:
    """Get the text content of each handwritten line in an IAM form XML file.

    Reads the ``text`` attribute of every ``handwritten-part/line`` element. The XML parser already decodes
    ``&quot;`` entities in attribute values; any literal ``&quot;`` that survives parsing (i.e. the source
    was double-escaped as ``&amp;quot;``) is replaced with a plain ``"``.

    Args:
        filename: Path (or file object) of the XML file to parse.

    Returns:
        The line texts, in document order.
    """
    xml_root_element = ET.parse(filename).getroot()  # nosec
    xml_line_elements = xml_root_element.findall("handwritten-part/line")
    return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements]


def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
    """Parse an IAM form XML file and return one region dict per handwritten line.

    Each ``handwritten-part/line`` element is converted by ``_get_line_region_from_xml_element``,
    preserving document order.
    """
    root = ET.parse(filename).getroot()  # nosec
    line_elements = root.findall("handwritten-part/line")
    return list(map(_get_line_region_from_xml_element, line_elements))


def _get_line_region_from_xml_element(
    xml_line: Any,
    downsample_factor: Optional[int] = None,
    padding: Optional[int] = None,
) -> Dict[str, int]:
    """Compute the bounding region of one handwritten line from its word components.

    The region is the union of all ``word/cmp`` child rectangles of ``xml_line`` (each given by integer
    ``x``, ``y``, ``width`` and ``height`` attributes). Its coordinates are floor-divided by
    ``downsample_factor`` and then expanded by ``padding`` pixels on every side.

    Args:
        xml_line: A ``line`` element from an IAM form XML file.
        downsample_factor: Divisor applied to all coordinates; defaults to ``DOWNSAMPLE_FACTOR``.
        padding: Pixels added around the region after downsampling; defaults to ``LINE_REGION_PADDING``.

    Returns:
        Dict with keys ``x1``/``y1`` (smallest x/y) and ``x2``/``y2`` (largest x+width / y+height).
        ``x1`` and ``y1`` are clamped at 0 so padding never yields a negative coordinate.

    Raises:
        ValueError: If the line has no ``word/cmp`` components.
    """
    # TODO: fix input!
    if downsample_factor is None:
        downsample_factor = DOWNSAMPLE_FACTOR
    if padding is None:
        padding = LINE_REGION_PADDING

    word_elements = xml_line.findall("word/cmp")
    if not word_elements:
        # min()/max() below would fail with an opaque "empty sequence" error.
        raise ValueError("IAM line element has no word/cmp components; cannot compute its region.")

    x1s = [int(el.attrib["x"]) for el in word_elements]
    y1s = [int(el.attrib["y"]) for el in word_elements]
    x2s = [x + int(el.attrib["width"]) for x, el in zip(x1s, word_elements)]
    y2s = [y + int(el.attrib["height"]) for y, el in zip(y1s, word_elements)]
    return {
        "x1": max(0, min(x1s) // downsample_factor - padding),
        "y1": max(0, min(y1s) // downsample_factor - padding),
        "x2": max(x2s) // downsample_factor + padding,
        "y2": max(y2s) // downsample_factor + padding,
    }


def main() -> None:
    """Entry point: build the IAM dataset, fetch it if needed, and print a summary."""
    iam = IamDataset()
    iam.load_or_generate_data()
    print(iam)


if __name__ == "__main__":
    main()