diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-28 22:02:24 +0200 |
commit | 46a1472d33d3a4180798492e819f2ec02bc3b1a3 (patch) | |
tree | 22322ed0d8f9f803966ea745ec5bb8c759f8db64 /text_recognizer/data/iam.py | |
parent | 8248f173132dfb7e47ec62b08e9235990c8626e3 (diff) |
Add refactor of iam lines
Diffstat (limited to 'text_recognizer/data/iam.py')
-rw-r--r-- | text_recognizer/data/iam.py | 39 |
1 files changed, 26 insertions, 13 deletions
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index fcfe9a7..01272ba 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -60,23 +60,36 @@ class IAM(BaseDataModule): @property def split_by_id(self) -> Dict[str, str]: - return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames} + return { + filename.stem: "test" + if filename.stem in self.metadata["test_ids"] + else "train" + for filename in self.form_filenames + } @cachedproperty def line_strings_by_id(self) -> Dict[str, List[str]]: """Return a dict from name of IAM form to list of line texts in it.""" - return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames} + return { + filename.stem: _get_line_strings_from_xml_file(filename) + for filename in self.xml_filenames + } @cachedproperty def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]: """Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it.""" - return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames} + return { + filename.stem: _get_line_regions_from_xml_file(filename) + for filename in self.xml_filenames + } def __repr__(self) -> str: """Return info about the dataset.""" - return ("IAM Dataset\n" - f"Num forms total: {len(self.xml_filenames)}\n" - f"Num in test set: {len(self.metadata['test_ids'])}\n") + return ( + "IAM Dataset\n" + f"Num forms total: {len(self.xml_filenames)}\n" + f"Num in test set: {len(self.metadata['test_ids'])}\n" + ) def _extract_raw_dataset(filename: Path, dirname: Path) -> None: @@ -92,7 +105,7 @@ def _get_line_strings_from_xml_file(filename: str) -> List[str]: """Get the text content of each line. 
Note that we replace &quot; with ".""" xml_root_element = ElementTree.parse(filename).getroot() # nosec xml_line_elements = xml_root_element.findall("handwritten-part/line") - return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements] + return [el.attrib["text"].replace("&quot;", '"') for el in xml_line_elements] def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]: @@ -107,13 +120,13 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]: x1s = [int(el.attrib["x"]) for el in word_elements] y1s = [int(el.attrib["y"]) for el in word_elements] x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements] - y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements] + y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements] return { - "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, - "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, - } + "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING, + "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING, + } def download_iam() -> None: |