diff options
Diffstat (limited to 'text_recognizer/data/iam_synthetic_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 56 |
1 files changed, 31 insertions, 25 deletions
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 52ed398..91fda4a 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -9,25 +9,20 @@ from PIL import Image from text_recognizer.data.base_data_module import load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.iam import IAM -from text_recognizer.data.iam_lines import ( - line_crops_and_labels, - load_line_crops_and_labels, - save_images_and_labels, -) from text_recognizer.data.iam_paragraphs import ( - IMAGE_SCALE_FACTOR, - NEW_LINE_TOKEN, IAMParagraphs, get_dataset_properties, resize_image, ) -from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.data.transforms.load_transform import load_transform_from_file -from text_recognizer.metadata import shared as metadata - -PROCESSED_DATA_DIRNAME = ( - metadata.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs" +from text_recognizer.data.iam_lines import ( + line_crops_and_labels, + load_line_crops_and_labels, + save_images_and_labels, ) +from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.stems.paragraph import ParagraphStem +from text_recognizer.data.transforms.pad import Pad +import text_recognizer.metadata.iam_synthetic_paragraphs as metadata class IAMSyntheticParagraphs(IAMParagraphs): @@ -57,26 +52,32 @@ class IAMSyntheticParagraphs(IAMParagraphs): def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" - if PROCESSED_DATA_DIRNAME.exists(): + if metadata.PROCESSED_DATA_DIRNAME.exists(): return log.info("Preparing IAM lines for synthetic paragraphs dataset.") log.info("Cropping IAM line regions and loading labels.") - iam = IAM(mapping=EmnistMapping(extra_symbols=(NEW_LINE_TOKEN,))) + iam = IAM(mapping=EmnistMapping(extra_symbols=(metadata.NEW_LINE_TOKEN,))) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") - crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train] - crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test] + crops_train = [ + resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops_train + ] + crops_test = [ + resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops_test + ] - log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}") + log.info(f"Saving images and labels at {metadata.PROCESSED_DATA_DIRNAME}") save_images_and_labels( - crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME + crops_train, labels_train, "train", metadata.PROCESSED_DATA_DIRNAME + ) + save_images_and_labels( + crops_test, labels_test, "test", metadata.PROCESSED_DATA_DIRNAME ) - save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) def setup(self, stage: str = None) -> None: """Loading synthetic dataset.""" @@ -85,7 +86,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): if stage == "fit" or stage is None: line_crops, line_labels = load_line_crops_and_labels( - "train", PROCESSED_DATA_DIRNAME + "train", metadata.PROCESSED_DATA_DIRNAME ) data, paragraphs_labels = generate_synthetic_paragraphs( line_crops=line_crops, line_labels=line_labels @@ -157,7 +158,7 @@ def generate_synthetic_paragraphs( paragraphs_crops, paragraphs_labels = [], [] for paragraph_indices in batched_indices_list: - paragraph_label = NEW_LINE_TOKEN.join( + paragraph_label = metadata.NEW_LINE_TOKEN.join( [line_labels[i] for i in paragraph_indices] ) if len(paragraph_label) > paragraphs_properties["label_length"]["max"]: @@ -236,8 +237,13 @@ def generate_random_batches( def create_synthetic_iam_paragraphs() -> None: """Creates and prints IAM Synthetic Paragraphs dataset.""" - transform = load_transform_from_file("transform/paragraphs.yaml") - test_transform = load_transform_from_file("test_transform/paragraphs.yaml") + transform = ParagraphStem() + test_transform = ParagraphStem() + target_transform = Pad(metadata.MAX_LABEL_LENGTH, 3) load_and_print_info( - IAMSyntheticParagraphs(transform=transform, test_transform=test_transform) + IAMSyntheticParagraphs( + transform=transform, + test_transform=test_transform, + target_transform=target_transform, + ) ) |