author    | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:08:04 +0200
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-27 00:08:04 +0200
commit    | 27ff7d113108e9cc51ddc5ff13b648b9c75fa865 (patch)
tree      | 96b35c2f65978b8718665aaded3d29f00aaf43e2 /text_recognizer/data/iam_paragraphs.py
parent    | 3227735099f8acb37ffe658b8f04b6c308b64d23 (diff)
Add metadata
Diffstat (limited to 'text_recognizer/data/iam_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 38
1 file changed, 14 insertions(+), 24 deletions(-)
```diff
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index c7d5229..eec1b1f 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -18,17 +18,7 @@ from text_recognizer.data.base_dataset import (
 from text_recognizer.data.iam import IAM
 from text_recognizer.data.mappings import EmnistMapping
 from text_recognizer.data.transforms.load_transform import load_transform_from_file
-
-PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs"
-
-NEW_LINE_TOKEN = "\n"
-
-SEED = 4711
-IMAGE_SCALE_FACTOR = 2
-IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR
-IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR
-MAX_LABEL_LENGTH = 682
-MAX_WORD_PIECE_LENGTH = 451
+from text_recognizer.metadata import iam_paragraphs as metadata
 
 
 class IAMParagraphs(BaseDataModule):
@@ -55,17 +45,17 @@ class IAMParagraphs(BaseDataModule):
             num_workers,
             pin_memory,
         )
-        self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
-        self.output_dims = (MAX_LABEL_LENGTH, 1)
+        self.dims = (1, metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH)
+        self.output_dims = (metadata.MAX_LABEL_LENGTH, 1)
 
     def prepare_data(self) -> None:
         """Create data for training/testing."""
-        if PROCESSED_DATA_DIRNAME.exists():
+        if metadata.PROCESSED_DATA_DIRNAME.exists():
             return
 
         log.info("Cropping IAM paragraph regions and saving them along with labels...")
 
-        iam = IAM(mapping=EmnistMapping(extra_symbols={NEW_LINE_TOKEN}))
+        iam = IAM(mapping=EmnistMapping(extra_symbols={metadata.NEW_LINE_TOKEN}))
         iam.prepare_data()
 
         properties = {}
@@ -84,7 +74,7 @@ class IAMParagraphs(BaseDataModule):
             }
         )
 
-        with (PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f:
+        with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f:
             json.dump(properties, f, indent=4)
 
     def setup(self, stage: str = None) -> None:
@@ -94,7 +84,7 @@ class IAMParagraphs(BaseDataModule):
             split: str, transform: T.Compose, target_transform: T.Compose
         ) -> BaseDataset:
             crops, labels = _load_processed_crops_and_labels(split)
-            data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
+            data = [resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops]
             targets = convert_strings_to_labels(
                 strings=labels,
                 mapping=self.mapping.inverse_mapping,
@@ -117,7 +107,7 @@ class IAMParagraphs(BaseDataModule):
                 target_transform=self.target_transform,
             )
             self.data_train, self.data_val = split_dataset(
-                dataset=data_train, fraction=self.train_fraction, seed=SEED
+                dataset=data_train, fraction=self.train_fraction, seed=metadata.SEED
             )
 
         if stage == "test" or stage is None:
@@ -162,7 +152,7 @@
 
 def get_dataset_properties() -> Dict:
     """Return properties describing the overall dataset."""
-    with (PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
+    with (metadata.PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
         properties = json.load(f)
 
     def _get_property_values(key: str) -> List:
@@ -193,7 +183,7 @@
     """Validates input and output dimensions against the properties of the dataset."""
    properties = get_dataset_properties()
 
-    max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR
+    max_image_shape = properties["crop_shape"]["max"] / metadata.IMAGE_SCALE_FACTOR
     if (
         input_dims is not None
         and input_dims[1] < max_image_shape[0]
@@ -246,7 +236,7 @@ def _get_paragraph_crops_and_labels(
         lines = iam.line_strings_by_id[id_]
 
         crops[id_] = image.crop(paragraph_box)
-        labels[id_] = NEW_LINE_TOKEN.join(lines)
+        labels[id_] = metadata.NEW_LINE_TOKEN.join(lines)
 
     if len(crops) != len(labels):
         raise ValueError(f"Crops ({len(crops)}) does not match labels ({len(labels)})")
@@ -258,7 +248,7 @@ def _save_crops_and_labels(
     crops: Dict[str, Image.Image], labels: Dict[str, str], split: str
 ) -> None:
     """Save crops, labels, and shapes of crops of a split."""
-    (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True)
+    (metadata.PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True)
 
     with _labels_filename(split).open("w") as f:
         json.dump(labels, f, indent=4)
@@ -289,12 +279,12 @@ def _load_processed_crops_and_labels(
 
 def _labels_filename(split: str) -> Path:
     """Return filename of processed labels."""
-    return PROCESSED_DATA_DIRNAME / split / "_labels.json"
+    return metadata.PROCESSED_DATA_DIRNAME / split / "_labels.json"
 
 
 def _crop_filename(id: str, split: str) -> Path:
     """Return filename of processed crop."""
-    return PROCESSED_DATA_DIRNAME / split / f"{id}.png"
+    return metadata.PROCESSED_DATA_DIRNAME / split / f"{id}.png"
 
 
 def _num_lines(label: str) -> int:
```
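The new `text_recognizer/metadata/iam_paragraphs.py` module itself is outside this diff (the diffstat is limited to `iam_paragraphs.py`), but the constants deleted above pin down what it must export. A minimal sketch, assuming the metadata package derives the processed-data root from its own file location rather than importing `BaseDataModule.data_dirname()` (which would pull the data module back in); the exact mechanism is not shown in this commit:

```python
"""Metadata for the IAM paragraphs dataset.

Sketch reconstructed from the constants removed in this commit; the real
module is not part of this diff.
"""
from pathlib import Path

# Assumption: the metadata package locates the data root itself instead of
# calling BaseDataModule.data_dirname(), avoiding a circular import.
DATA_DIRNAME = Path(__file__).resolve().parents[2] / "data"
PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_paragraphs"

NEW_LINE_TOKEN = "\n"

SEED = 4711
IMAGE_SCALE_FACTOR = 2
IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR  # 576
IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR  # 640
MAX_LABEL_LENGTH = 682
MAX_WORD_PIECE_LENGTH = 451
```

Centralizing these values in a lightweight metadata package gives the data module, network configurations, and notebooks a single source of truth for image dimensions and label lengths, instead of each importing module-level globals from `iam_paragraphs.py`.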