author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 11:31:00 +0200
---|---|---
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 11:31:00 +0200
commit | 186edf0890953f070cf707b6c3aef26961e1721f (patch) |
tree | 47dec0d107b4c6b6725a15f7c99bf9f71ae6e7f3 /text_recognizer/data/iam_synthetic_paragraphs.py |
parent | 07f2cc3665a1a60efe8ed8073cad6ac4f344b2c2 (diff) |
Add synthetic IAM paragraphs dataset
Diffstat (limited to 'text_recognizer/data/iam_synthetic_paragraphs.py')
-rw-r--r-- | text_recognizer/data/iam_synthetic_paragraphs.py | 224 |
1 files changed, 224 insertions, 0 deletions
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
new file mode 100644
index 0000000..9f1bd12
--- /dev/null
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -0,0 +1,224 @@
+"""IAM Synthetic Paragraphs Dataset class."""
+import itertools
+from pathlib import Path
+import random
+import time
+from typing import Any, List, Sequence, Tuple
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+
+from text_recognizer.data.base_dataset import (
+    BaseDataset,
+    convert_strings_to_labels,
+    split_dataset,
+)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.iam_paragraphs import (
+    get_dataset_properties,
+    get_transform,
+    NEW_LINE_TOKEN,
+    IAMParagraphs,
+    IMAGE_SCALE_FACTOR,
+    resize_image,
+)
+from text_recognizer.data.iam import IAM
+from text_recognizer.data.iam_lines import (
+    line_crops_and_labels,
+    load_line_crops_and_labels,
+    save_images_and_labels,
+)
+
+
+PROCESSED_DATA_DIRNAME = (
+    BaseDataModule.data_dirname() / "processed" / "iam_synthetic_paragraphs"
+)
+
+
+class IAMSyntheticParagraphs(IAMParagraphs):
+    """IAM Handwriting database of synthetic paragraphs."""
+
+    def __init__(
+        self,
+        batch_size: int = 128,
+        num_workers: int = 0,
+        train_fraction: float = 0.8,
+        augment: bool = True,
+    ) -> None:
+        super().__init__(batch_size, num_workers, train_fraction, augment)
+
+    def prepare_data(self) -> None:
+        """Prepare IAM lines to be used to generate paragraphs."""
+        if PROCESSED_DATA_DIRNAME.exists():
+            return
+
+        logger.info("Preparing IAM lines for synthetic paragraphs dataset.")
+        logger.info("Cropping IAM line regions and loading labels.")
+
+        iam = IAM()
+        iam.prepare_data()
+
+        crops_train, labels_train = line_crops_and_labels(iam, "train")
+        crops_test, labels_test = line_crops_and_labels(iam, "test")
+
+        crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train]
+        crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test]
+
+        logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
+        save_images_and_labels(
+            crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
+        )
+        save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)
+
+    def setup(self, stage: str = None) -> None:
+        """Load the synthetic dataset."""
+        logger.info(f"IAM Synthetic dataset setup for stage {stage}")
+
+        if stage == "fit" or stage is None:
+            line_crops, line_labels = load_line_crops_and_labels(
+                "train", PROCESSED_DATA_DIRNAME
+            )
+            data, paragraphs_labels = generate_synthetic_paragraphs(
+                line_crops=line_crops, line_labels=line_labels
+            )
+
+            targets = convert_strings_to_labels(
+                strings=paragraphs_labels,
+                mapping=self.inverse_mapping,
+                length=self.output_dims[0],
+            )
+            self.data_train = BaseDataset(
+                data,
+                targets,
+                transform=get_transform(
+                    image_shape=self.dims[1:], augment=self.augment
+                ),
+            )
+
+    def __repr__(self) -> str:
+        """Return information about the dataset."""
+        basic = (
+            "IAM Synthetic Paragraphs Dataset\n"  # pylint: disable=no-member
+            f"Num classes: {len(self.mapping)}\n"
+            f"Input dims : {self.dims}\n"
+            f"Output dims: {self.output_dims}\n"
+        )
+        if self.data_train is None:
+            return basic
+
+        x, y = next(iter(self.train_dataloader()))
+        data = (
+            f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n"
+            f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+            f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+        )
+        return basic + data
+
+
+def generate_synthetic_paragraphs(
+    line_crops: List[Image.Image], line_labels: List[str], max_batch_size: int = 9
+) -> Tuple[List[Image.Image], List[str]]:
+    """Generate synthetic paragraphs by randomly joining different subsets of lines."""
+    paragraphs_properties = get_dataset_properties()
+
+    indices = list(range(len(line_labels)))
+
+    if max_batch_size >= paragraphs_properties["num_lines"]["max"]:
+        raise ValueError("max_batch_size greater than or equal to max num lines.")
+
+    batched_indices_list = [[index] for index in indices]
+    batched_indices_list.extend(
+        generate_random_batches(
+            values=indices, min_batch_size=2, max_batch_size=max_batch_size // 2
+        )
+    )
+    batched_indices_list.extend(
+        generate_random_batches(
+            values=indices, min_batch_size=2, max_batch_size=max_batch_size
+        )
+    )
+    batched_indices_list.extend(
+        generate_random_batches(
+            values=indices,
+            min_batch_size=max_batch_size // 2 + 1,
+            max_batch_size=max_batch_size,
+        )
+    )
+
+    paragraphs_crops, paragraphs_labels = [], []
+    for paragraph_indices in batched_indices_list:
+        paragraph_label = NEW_LINE_TOKEN.join(
+            [line_labels[i] for i in paragraph_indices]
+        )
+        if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
+            logger.info(
+                "Label longer than longest label in original IAM paragraphs dataset - hence dropping."
+            )
+            continue
+
+        paragraph_crop = join_line_crops_to_form_paragraph(
+            [line_crops[i] for i in paragraph_indices]
+        )
+        max_paragraph_shape = paragraphs_properties["crop_shape"]["max"]
+
+        if (
+            paragraph_crop.height > max_paragraph_shape[0]
+            or paragraph_crop.width > max_paragraph_shape[1]
+        ):
+            logger.info(
+                "Crop larger than largest crop in original IAM paragraphs dataset - hence dropping."
+            )
+            continue
+
+        paragraphs_crops.append(paragraph_crop)
+        paragraphs_labels.append(paragraph_label)
+
+    if len(paragraphs_crops) != len(paragraphs_labels):
+        raise ValueError("Number of crops does not match number of labels.")
+
+    return paragraphs_crops, paragraphs_labels
+
+
+def join_line_crops_to_form_paragraph(
+    line_crops: Sequence[Image.Image],
+) -> Image.Image:
+    """Vertically stack line crops and return a single image forming a paragraph."""
+    crop_shapes = np.array([line.size[::-1] for line in line_crops])
+    paragraph_height = crop_shapes[:, 0].sum()
+    paragraph_width = crop_shapes[:, 1].max()
+
+    paragraph_image = Image.new(
+        mode="L", size=(paragraph_width, paragraph_height), color=0
+    )
+    current_height = 0
+    for line_crop in line_crops:
+        paragraph_image.paste(line_crop, box=(0, current_height))
+        current_height += line_crop.height
+
+    return paragraph_image
+
+
+def generate_random_batches(
+    values: List[Any], min_batch_size: int, max_batch_size: int
+) -> List[List[Any]]:
+    """Generate random batches of elements in values without replacement."""
+    shuffled_values = values.copy()
+    random.shuffle(shuffled_values)
+
+    start_index = 0
+    grouped_values_list = []
+    while start_index < len(shuffled_values):
+        num_values = random.randint(min_batch_size, max_batch_size)
+        grouped_values_list.append(
+            shuffled_values[start_index : start_index + num_values]
+        )
+        start_index += num_values
+
+    if sum([len(grp) for grp in grouped_values_list]) != len(values):
+        raise ValueError("Length of groups does not match length of values.")
+
+    return grouped_values_list
+
+
+def create_synthetic_iam_paragraphs() -> None:
+    """Load the synthetic paragraphs dataset and print its info string."""
+    load_and_print_info(IAMSyntheticParagraphs)
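
For context, a minimal sketch of how the new data module might be driven end to end. It assumes the standard PyTorch Lightning `prepare_data`/`setup` flow that the class inherits via `IAMParagraphs`; the driver script itself is hypothetical and not part of this commit:

```python
"""Hypothetical driver -- not part of this commit.

Assumes the repository's text_recognizer package is importable and the
raw IAM data needed by IAM().prepare_data() is available locally.
"""
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs


def main() -> None:
    dataset = IAMSyntheticParagraphs(batch_size=128, num_workers=0)
    dataset.prepare_data()  # crop IAM lines once, cache under PROCESSED_DATA_DIRNAME
    dataset.setup("fit")    # join cached line crops into synthetic paragraphs
    print(dataset)          # __repr__ reports dims and train-batch statistics


if __name__ == "__main__":
    main()
```

The committed `create_synthetic_iam_paragraphs()` helper wraps the same flow through `load_and_print_info`.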
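The paragraph generator draws four pools of line-index batches: every line on its own, short batches (2 to `max_batch_size // 2` lines), mixed batches (2 to `max_batch_size`), and long batches (`max_batch_size // 2 + 1` to `max_batch_size`), so each IAM line appears in roughly four synthetic paragraphs of varying length. Below is a self-contained toy run of the batching helper, with its logic copied from the diff; the exact grouping depends on the random seed:

```python
import random


def generate_random_batches(values, min_batch_size, max_batch_size):
    """Split a shuffled copy of values into randomly sized groups (as in the diff)."""
    shuffled_values = values.copy()
    random.shuffle(shuffled_values)
    start_index, grouped_values_list = 0, []
    while start_index < len(shuffled_values):
        num_values = random.randint(min_batch_size, max_batch_size)
        grouped_values_list.append(shuffled_values[start_index : start_index + num_values])
        start_index += num_values
    return grouped_values_list


batches = generate_random_batches(list(range(10)), min_batch_size=2, max_batch_size=4)
# Every index appears exactly once; groups hold 2-4 elements, except that
# the final group may come up short of min_batch_size.
print(batches)
```

Note that the trailing group can be smaller than `min_batch_size`; the length check in the committed function only verifies that no element is lost or duplicated.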
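Finally, a toy geometry check for `join_line_crops_to_form_paragraph`, assuming the committed module is importable: the paragraph canvas is a black mode-"L" image whose width is the widest line and whose height is the sum of the line heights.

```python
from PIL import Image

from text_recognizer.data.iam_synthetic_paragraphs import (
    join_line_crops_to_form_paragraph,
)

# Three blank "line crops" of differing widths, each 28 px tall.
lines = [Image.new("L", (width, 28), color=255) for width in (120, 80, 100)]
paragraph = join_line_crops_to_form_paragraph(lines)
assert paragraph.size == (120, 3 * 28)  # width = max, height = sum
```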