"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" import os from pathlib import Path from typing import Any, Dict, List import xml.etree.ElementTree as ElementTree import zipfile from boltons.cacheutils import cachedproperty from loguru import logger from PIL import Image import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.download_utils import download_dataset RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam" EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. class IAM(BaseDataModule): """ "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels. From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database The data split we will use is IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. The validation set has been merged into the train set. The train set has 7,101 lines from 326 writers. The test set has 1,861 lines from 128 writers. The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. """ def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: super().__init__(batch_size, num_workers) self.metadata = toml.load(METADATA_FILENAME) def prepare_data(self) -> None: if self.xml_filenames: return filename = download_dataset(self.metadata, DL_DATA_DIRNAME) _extract_raw_dataset(filename, DL_DATA_DIRNAME) @property def xml_filenames(self) -> List[Path]: return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) @property def form_filenames(self) -> List[Path]: return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) @property def form_filenames_by_id(self) -> Dict[str, Path]: return {filename.stem: filename for filename in self.form_filenames} @property def split_by_id(self) -> Dict[str, str]: return { filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "train" for filename in self.form_filenames } @cachedproperty def line_strings_by_id(self) -> Dict[str, List[str]]: """Return a dict from name of IAM form to list of line texts in it.""" return { filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames } @cachedproperty def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" return { filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames } def __repr__(self) -> str: """Return info about the dataset.""" return ( "IAM Dataset\n" f"Num forms total: {len(self.xml_filenames)}\n" f"Num in test set: {len(self.metadata['test_ids'])}\n" ) def _extract_raw_dataset(filename: Path, dirname: Path) -> None: logger.info("Extracting IAM data...") curdir = os.getcwd() os.chdir(dirname) with zipfile.ZipFile(filename, "r") as f: f.extractall() os.chdir(curdir) def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. Note that we replace ": with ".""" xml_root_element = ElementTree.parse(filename).getroot() # nosec xml_line_elements = xml_root_element.findall("handwritten-part/line") return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: """Get line region dict for each line.""" xml_root_element = ElementTree.parse(filename).getroot() # nosec xml_line_elements = xml_root_element.findall("handwritten-part/line") return [_get_line_region_from_xml_file(el) for el in xml_line_elements] def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: word_elements = xml_line.findall("word/cmp") x1s = [int(el.attrib["x"]) for el in word_elements] y1s = [int(el.attrib["y"]) for el in word_elements] x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] return { "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, } def download_iam() -> None: load_and_print_info(IAM)