From 46a1472d33d3a4180798492e819f2ec02bc3b1a3 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 28 Mar 2021 22:02:24 +0200 Subject: Add refactor of iam lines --- text_recognizer/data/iam.py | 39 ++++++++++++++++++++++++++------------- 1 file changed, 26 insertions(+), 13 deletions(-) (limited to 'text_recognizer/data/iam.py') diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index fcfe9a7..01272ba 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -60,23 +60,36 @@ class IAM(BaseDataModule): @property def split_by_id(self) -> Dict[str, str]: - return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames} + return { + filename.stem: "test" + if filename.stem in self.metadata["test_ids"] + else "train" + for filename in self.form_filenames + } @cachedproperty def line_strings_by_id(self) -> Dict[str, List[str]]: """Return a dict from name of IAM form to list of line texts in it.""" - return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } @cachedproperty def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } def __repr__(self) -> str: """Return info about the dataset.""" - return ("IAM Dataset\n" - f"Num forms total: {len(self.xml_filenames)}\n" - f"Num in test set: {len(self.metadata['test_ids'])}\n") + return ( + "IAM Dataset\n" + f"Num forms total: {len(self.xml_filenames)}\n" + f"Num in test set: {len(self.metadata['test_ids'])}\n" + ) def _extract_raw_dataset(filename: Path, dirname: Path) -> None: @@ -92,7 +105,7 @@ def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. Note that we replace ": with ".""" xml_root_element = ElementTree.parse(filename).getroot() # nosec xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] + return [el.attrib["text"].replace(""", '"') for el in xml_line_elements] def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: @@ -107,13 +120,13 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: x1s = [int(el.attrib["x"]) for el in word_elements] y1s = [int(el.attrib["y"]) for el in word_elements] x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } def download_iam() -> None: -- cgit v1.2.3-70-g09d2