summaryrefslogtreecommitdiff
path: root/text_recognizer/data/iam.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam.py')
-rw-r--r--text_recognizer/data/iam.py39
1 files changed, 26 insertions, 13 deletions
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index fcfe9a7..01272ba 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -60,23 +60,36 @@ class IAM(BaseDataModule):
@property
def split_by_id(self) -> Dict[str, str]:
- return {filename.stem: "test" if filename.stem in self.metadata["test_ids"] else "trainval" for filename in self.form_filenames}
+ return {
+ filename.stem: "test"
+ if filename.stem in self.metadata["test_ids"]
+ else "train"
+ for filename in self.form_filenames
+ }
@cachedproperty
def line_strings_by_id(self) -> Dict[str, List[str]]:
"""Return a dict from name of IAM form to list of line texts in it."""
- return {filename.stem: _get_line_strings_from_xml_file(filename) for filename in self.xml_filenames}
+ return {
+ filename.stem: _get_line_strings_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
@cachedproperty
def line_regions_by_id(self) -> Dict[str, List[Dict[str, int]]]:
"""Return a dict from name IAM form to list of (x1, x2, y1, y2) coordinates of all lines in it."""
- return {filename.stem: _get_line_regions_from_xml_file(filename) for filename in self.xml_filenames}
+ return {
+ filename.stem: _get_line_regions_from_xml_file(filename)
+ for filename in self.xml_filenames
+ }
def __repr__(self) -> str:
"""Return info about the dataset."""
- return ("IAM Dataset\n"
- f"Num forms total: {len(self.xml_filenames)}\n"
- f"Num in test set: {len(self.metadata['test_ids'])}\n")
+ return (
+ "IAM Dataset\n"
+ f"Num forms total: {len(self.xml_filenames)}\n"
+ f"Num in test set: {len(self.metadata['test_ids'])}\n"
+ )
def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
@@ -92,7 +105,7 @@ def _get_line_strings_from_xml_file(filename: str) -> List[str]:
"""Get the text content of each line. Note that we replace &quot: with "."""
xml_root_element = ElementTree.parse(filename).getroot() # nosec
xml_line_elements = xml_root_element.findall("handwritten-part/line")
- return [el.attrib["text"].replace("&quot", '"') for el in xml_line_elements]
+ return [el.attrib["text"].replace(""", '"') for el in xml_line_elements]
def _get_line_regions_from_xml_file(filename: str) -> List[Dict[str, int]]:
@@ -107,13 +120,13 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]:
x1s = [int(el.attrib["x"]) for el in word_elements]
y1s = [int(el.attrib["y"]) for el in word_elements]
x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
- y2s = [int(el.attrib["x"]) + int(el.attrib["height"]) for el in word_elements]
+ y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
return {
- "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "x2": min(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- "y2": min(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- }
+ "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
+ "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ }
def download_iam() -> None: