From 8248f173132dfb7e47ec62b08e9235990c8626e3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 24 Mar 2021 22:15:54 +0100 Subject: renamed datasets to data, added iam refactor --- text_recognizer/data/iam.py | 120 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) create mode 100644 text_recognizer/data/iam.py (limited to 'text_recognizer/data/iam.py') diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py new file mode 100644 index 0000000..fcfe9a7 --- /dev/null +++ b/text_recognizer/data/iam.py @@ -0,0 +1,120 @@ +"""Class for loading the IAM dataset, which encompasses both paragraphs and lines, with associated utilities.""" +import os +from pathlib import Path +from typing import Any, Dict, List +import xml.etree.ElementTree as ElementTree +import zipfile + +from boltons.cacheutils import cachedproperty +from loguru import logger +from PIL import Image +import toml + +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info +from text_recognizer.data.download_utils import download_dataset + + +RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam" +METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" +DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam" +EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb" + +DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. +LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates. + + +class IAM(BaseDataModule): + """ + "The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text, + which were scanned at a resolution of 300dpi and saved as PNG images with 256 gray levels. + From http://www.fki.inf.unibe.ch/databases/iam-handwriting-database + The data split we will use is + IAM lines Large Writer Independent Text Line Recognition Task (lwitlrt): 9,862 text lines. + The validation set has been merged into the train set. + The train set has 7,101 lines from 326 writers. + The test set has 1,861 lines from 128 writers. + The text lines of all data sets are mutually exclusive, thus each writer has contributed to one set only. + """ + + def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None: + super().__init__(batch_size, num_workers) + self.metadata = toml.load(METADATA_FILENAME) + + def prepare_data(self) -> None: + if self.xml_filenames: + return + filename = download_dataset(self.metadata, DL_DATA_DIRNAME) + _extract_raw_dataset(filename, DL_DATA_DIRNAME) + + @property + def xml_filenames(self) -> List[Path]: + return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml")) + + @property + def form_filenames(self) -> List[Path]: + return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg")) + + @property + def form_filenames_by_id(self) -> Dict[str, Path]: + return {filename.stem: filename for filename in self.form_filenames} + + @property + def split_by_id(self) -> Dict[str, str]: + return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames} + + @cachedproperty + def line_strings_by_id(self) -> Dict[str, List[str]]: + """Return a dict from name of IAM form to list of line texts in it.""" + return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} + + @cachedproperty + def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: + """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" + return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} + + def __repr__(self) -> str: + """Return info about the dataset.""" + return ("IAM Dataset\n" + f"Num forms total: {len(self.xml_filenames)}\n" + f"Num in test set: {len(self.metadata['test_ids'])}\n") + + +def _extract_raw_dataset(filename: Path, dirname: Path) -> None: + logger.info("Extracting IAM data...") + curdir = os.getcwd() + os.chdir(dirname) + with zipfile.ZipFile(filename, "r") as f: + f.extractall() + os.chdir(curdir) + + +def _get_line_strings_from_xml_file(filename: str) -> List[str]: + """Get the text content of each line. Note that we replace ": with ".""" + xml_root_element = ElementTree.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + + +def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: + """Get line region dict for each line.""" + xml_root_element = ElementTree.parse(filename).getroot() # nosec + xml_line_elements = xml_root_element.findall("handwritten-part/line") + return [_get_line_region_from_xml_file(el) for el in xml_line_elements] + + +def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: + word_elements = xml_line.findall("word/cmp") + x1s = [int(el.attrib["x"]) for el in word_elements] + y1s = [int(el.attrib["y"]) for el in word_elements] + x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] + y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements] + return { + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } + + +def download_iam() -> None: + load_and_print_info(IAM) -- cgit v1.2.3-70-g09d2