summary | refs | log | tree | commit | diff
path: root/text_recognizer/data/iam.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/iam.py')
-rw-r--r-- text_recognizer/data/iam.py 27
1 files changed, 10 insertions, 17 deletions
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index e3baf88..c20b50b 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -15,14 +15,7 @@ from loguru import logger as log
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.utils.download_utils import download_dataset
-
-RAW_DATA_DIRNAME = BaseDataModule.data_dirname() / "raw" / "iam"
-METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
-DL_DATA_DIRNAME = BaseDataModule.data_dirname() / "downloaded" / "iam"
-EXTRACTED_DATASET_DIRNAME = DL_DATA_DIRNAME / "iamdb"
-
-DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
-LINE_REGION_PADDING = 16 # Add this many pixels around the exact coordinates.
+from text_recognizer.metadata import iam as metadata
class IAM(BaseDataModule):
@@ -44,24 +37,24 @@ class IAM(BaseDataModule):
def __init__(self) -> None:
super().__init__()
- self.metadata: Dict = toml.load(METADATA_FILENAME)
+ self.metadata: Dict = toml.load(metadata.METADATA_FILENAME)
def prepare_data(self) -> None:
"""Prepares the IAM dataset."""
if self.xml_filenames:
return
- filename = download_dataset(self.metadata, DL_DATA_DIRNAME)
- _extract_raw_dataset(filename, DL_DATA_DIRNAME)
+ filename = download_dataset(self.metadata, metadata.DL_DATA_DIRNAME)
+ _extract_raw_dataset(filename, metadata.DL_DATA_DIRNAME)
@property
def xml_filenames(self) -> List[Path]:
"""Returns the xml filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
+ return list((metadata.EXTRACTED_DATASET_DIRNAME / "xml").glob("*.xml"))
@property
def form_filenames(self) -> List[Path]:
"""Returns the form filenames."""
- return list((EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
+ return list((metadata.EXTRACTED_DATASET_DIRNAME / "forms").glob("*.jpg"))
@property
def form_filenames_by_id(self) -> Dict[str, Path]:
@@ -133,10 +126,10 @@ def _get_line_region_from_xml_file(xml_line: Any) -> Dict[str, int]:
x2s = [int(el.attrib["x"]) + int(el.attrib["width"]) for el in word_elements]
y2s = [int(el.attrib["y"]) + int(el.attrib["height"]) for el in word_elements]
return {
- "x1": min(x1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "y1": min(y1s) // DOWNSAMPLE_FACTOR - LINE_REGION_PADDING,
- "x2": max(x2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
- "y2": max(y2s) // DOWNSAMPLE_FACTOR + LINE_REGION_PADDING,
+ "x1": min(x1s) // metadata.DOWNSAMPLE_FACTOR - metadata.LINE_REGION_PADDING,
+ "y1": min(y1s) // metadata.DOWNSAMPLE_FACTOR - metadata.LINE_REGION_PADDING,
+ "x2": max(x2s) // metadata.DOWNSAMPLE_FACTOR + metadata.LINE_REGION_PADDING,
+ "y2": max(y2s) // metadata.DOWNSAMPLE_FACTOR + metadata.LINE_REGION_PADDING,
}