summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py75
-rw-r--r--text_recognizer/data/iam_paragraphs.py16
-rw-r--r--text_recognizer/data/iam_synthetic_paragraphs.py224
3 files changed, 307 insertions, 8 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
new file mode 100644
index 0000000..51050fc
--- /dev/null
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -0,0 +1,75 @@
+"""IAM original and sythetic dataset class."""
+from torch.utils.data import ConcatDataset
+
+from text_recognizer.data.base_dataset import BaseDataset
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.iam_paragraphs import IAMParagraphs
+from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
+
+
+class IAMExtendedParagraphs(BaseDataModule):
+ def __init__(
+ self,
+ batch_size: int = 128,
+ num_workers: int = 0,
+ train_fraction: float = 0.8,
+ augment: bool = True,
+ ) -> None:
+ super().__init__(batch_size, num_workers)
+
+ self.iam_paragraphs = IAMParagraphs(
+ batch_size, num_workers, train_fraction, augment,
+ )
+ self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
+ batch_size, num_workers, train_fraction, augment,
+ )
+
+ self.dims = self.iam_paragraphs.dims
+ self.output_dims = self.iam_paragraphs.output_dims
+ self.mapping = self.iam_paragraphs.mapping
+ self.inverse_mapping = self.iam_paragraphs.inverse_mapping
+
+ self.data_train: BaseDataset = None
+ self.data_val: BaseDataset = None
+ self.data_test: BaseDataset = None
+
+ def prepare_data(self) -> None:
+ """Prepares the paragraphs data."""
+ self.iam_paragraphs.prepare_data()
+ self.iam_synthetic_paragraphs.prepare_data()
+
+ def setup(self, stage: str = None) -> None:
+ """Loads data for training/testing."""
+ self.iam_paragraphs.setup(stage)
+ self.iam_synthetic_paragraphs.setup(stage)
+
+ self.data_train = ConcatDataset(
+ [self.iam_paragraphs.data_train, self.iam_synthetic_paragraphs.data_train]
+ )
+ self.data_val = self.iam_paragraphs.data_val
+ self.data_test = self.iam_paragraphs.data_test
+
+ def __repr__(self) -> str:
+ """Returns info about the dataset."""
+ basic = (
+ "IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
+ f"Num classes: {len(self.mapping)}\n"
+ f"Dims: {self.dims}\n"
+ f"Output dims: {self.output_dims}\n"
+ )
+ if self.data_train is None and self.data_val is None and self.data_test is None:
+ return basic
+
+ x, y = next(iter(self.train_dataloader()))
+ xt, yt = next(iter(self.test_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, {len(self.data_val)}, {len(self.data_test)}\n"
+ f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+ f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n"
+ f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n"
+ )
+ return basic + data
+
+def show_dataset_info() -> None:
+ load_and_print_info(IAMExtendedParagraphs)
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 402a8d4..f588587 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -93,14 +93,14 @@ class IAMParagraphs(BaseDataModule):
def _load_dataset(split: str, augment: bool) -> BaseDataset:
crops, labels = _load_processed_crops_and_labels(split)
- data = [_resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
+ data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
)
return BaseDataset(
data,
targets,
- transform=_get_transform(image_shape=self.dims[1:], augment=augment),
+ transform=get_transform(image_shape=self.dims[1:], augment=augment),
)
logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
@@ -142,7 +142,7 @@ class IAMParagraphs(BaseDataModule):
return basic + data
-def _get_dataset_properties() -> Dict:
+def get_dataset_properties() -> Dict:
"""Return properties describing the overall dataset."""
with (PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f:
properties = json.load(f)
@@ -173,7 +173,7 @@ def _validate_data_dims(
input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]]
) -> None:
"""Validates input and output dimensions against the properties of the dataset."""
- properties = _get_dataset_properties()
+ properties = get_dataset_properties()
max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR
if (
@@ -192,7 +192,7 @@ def _validate_data_dims(
)
-def _resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
+def resize_image(image: Image.Image, scale_factor: int) -> Image.Image:
"""Resize image by scale factor."""
if scale_factor == 1:
return image
@@ -219,7 +219,7 @@ def _get_paragraph_crops_and_labels(
image = ImageOps.invert(image)
line_regions = iam.line_regions_by_id[id_]
- parameter_box = [
+ paragraph_box = [
min([region["x1"] for region in line_regions]),
min([region["y1"] for region in line_regions]),
max([region["x2"] for region in line_regions]),
@@ -227,7 +227,7 @@ def _get_paragraph_crops_and_labels(
]
lines = iam.line_strings_by_id[id_]
- crops[id_] = image.crop(parameter_box)
+ crops[id_] = image.crop(paragraph_box)
labels[id_] = NEW_LINE_TOKEN.join(lines)
if len(crops) != len(labels):
@@ -269,7 +269,7 @@ def _load_processed_crops_and_labels(
return ordered_crops, ordered_labels
-def _get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
+def get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose:
"""Get transformations for images."""
if augment:
transforms_list = [
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
new file mode 100644
index 0000000..9f1bd12
--- /dev/null
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -0,0 +1,224 @@
+"""IAM Synthetic Paragraphs Dataset class."""
+import itertools
+from pathlib import Path
+import random
+import time
+from typing import Any, List, Sequence, Tuple
+
+from loguru import logger
+import numpy as np
+from PIL import Image
+
+from text_recognizer.data.base_dataset import (
+ BaseDataset,
+ convert_strings_to_labels,
+ split_dataset,
+)
+from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
+from text_recognizer.data.iam_paragraphs import (
+ get_dataset_properties,
+ get_transform,
+ NEW_LINE_TOKEN,
+ IAMParagraphs,
+ IMAGE_SCALE_FACTOR,
+ resize_image,
+)
+from text_recognizer.data.iam import IAM
+from text_recognizer.data.iam_lines import (
+ line_crops_and_labels,
+ load_line_crops_and_labels,
+ save_images_and_labels,
+)
+
+
+PROCESSED_DATA_DIRNAME = (
+ BaseDataModule.data_dirname() / "processed" / "iam_synthetic_paragraphs"
+)
+
+
+class IAMSyntheticParagraphs(IAMParagraphs):
+ """IAM Handwriting database of synthetic paragraphs."""
+
+ def __init__(
+ self,
+ batch_size: int = 128,
+ num_workers: int = 0,
+ train_fraction: float = 0.8,
+ augment: bool = True,
+ ) -> None:
+ super().__init__(batch_size, num_workers, train_fraction, augment)
+
+ def prepare_data(self) -> None:
+ """Prepare IAM lines to be used to generate paragraphs."""
+ if PROCESSED_DATA_DIRNAME.exists():
+ return
+
+ logger.info("Preparing IAM lines for synthetic paragraphs dataset.")
+ logger.info("Cropping IAM line regions and loading labels.")
+
+ iam = IAM()
+ iam.prepare_data()
+
+ crops_train, labels_train = line_crops_and_labels(iam, "train")
+ crops_test, labels_test = line_crops_and_labels(iam, "test")
+
+ crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train]
+ crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test]
+
+ logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
+ save_images_and_labels(
+ crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
+ )
+ save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME)
+
+ def setup(self, stage: str = None) -> None:
+ """Loading synthetic dataset."""
+
+ logger.info(f"IAM Synthetic dataset steup for stage {stage}")
+
+ if stage == "fit" or stage is None:
+ line_crops, line_labels = load_line_crops_and_labels(
+ "train", PROCESSED_DATA_DIRNAME
+ )
+ data, paragraphs_labels = generate_synthetic_paragraphs(
+ line_crops=line_crops, line_labels=line_labels
+ )
+
+ targets = convert_strings_to_labels(
+ strings=paragraphs_labels,
+ mapping=self.inverse_mapping,
+ length=self.output_dims[0],
+ )
+ self.data_train = BaseDataset(
+ data,
+ targets,
+ transform=get_transform(
+ image_shape=self.dims[1:], augment=self.augment
+ ),
+ )
+
+ def __repr__(self) -> str:
+ """Return information about the dataset."""
+ basic = (
+ "IAM Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
+ f"Num classes: {len(self.mapping)}\n"
+ f"Input dims : {self.dims}\n"
+ f"Output dims: {self.output_dims}\n"
+ )
+ if self.data_train is None:
+ return basic
+
+ x, y = next(iter(self.train_dataloader()))
+ data = (
+ f"Train/val/test sizes: {len(self.data_train)}, 0, 0\n"
+ f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n"
+ f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n"
+ )
+ return basic + data
+
+
+def generate_synthetic_paragraphs(
+ line_crops: List[Image.Image], line_labels: List[str], max_batch_size: int = 9
+) -> Tuple[List[Image.Image], List[str]]:
+ """Generate synthetic paragraphs from randomly joining different subsets."""
+ paragraphs_properties = get_dataset_properties()
+
+ indices = list(range(len(line_labels)))
+
+ if max_batch_size >= paragraphs_properties["num_lines"]["max"]:
+ raise ValueError("max_batch_size greater or equalt to max num lines.")
+
+ batched_indices_list = [[index] for index in indices]
+ batched_indices_list.extend(
+ generate_random_batches(
+ values=indices, min_batch_size=2, max_batch_size=max_batch_size // 2
+ )
+ )
+ batched_indices_list.extend(
+ generate_random_batches(
+ values=indices, min_batch_size=2, max_batch_size=max_batch_size
+ )
+ )
+ batched_indices_list.extend(
+ generate_random_batches(
+ values=indices,
+ min_batch_size=max_batch_size // 2 + 1,
+ max_batch_size=max_batch_size,
+ )
+ )
+
+ paragraphs_crops, paragraphs_labels = [], []
+ for paragraph_indices in batched_indices_list:
+ paragraph_label = NEW_LINE_TOKEN.join(
+ [line_labels[i] for i in paragraph_indices]
+ )
+ if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
+ logger.info(
+ "Label longer than longest label in original IAM paragraph dataset - hence dropping."
+ )
+ continue
+
+ paragraph_crop = join_line_crops_to_form_paragraph(
+ [line_crops[i] for i in paragraph_indices]
+ )
+ max_paragraph_shape = paragraphs_properties["crop_shape"]["max"]
+
+ if (
+ paragraph_crop.height > max_paragraph_shape[0]
+ or paragraph_crop.width > max_paragraph_shape[1]
+ ):
+ logger.info(
+ "Crop larger than largest crop in original IAM paragraphs dataset - hence dropping"
+ )
+ continue
+
+ paragraphs_crops.append(paragraph_crop)
+ paragraphs_labels.append(paragraph_label)
+
+ if len(paragraphs_crops) != len(paragraphs_labels):
+ raise ValueError("Number of crops does not match number of labels.")
+
+ return paragraphs_crops, paragraphs_labels
+
+
+def join_line_crops_to_form_paragraph(line_crops: Sequence[Image.Image]) -> Image.Image:
+ """Horizontally stack line crops and return a single image forming a paragraph."""
+ crop_shapes = np.array([line.size[::-1] for line in line_crops])
+ paragraph_height = crop_shapes[:, 0].sum()
+ paragraph_width = crop_shapes[:, 1].max()
+
+ paragraph_image = Image.new(
+ mode="L", size=(paragraph_width, paragraph_height), color=0
+ )
+ current_height = 0
+ for line_crop in line_crops:
+ paragraph_image.paste(line_crop, box=(0, current_height))
+ current_height += line_crop.height
+
+ return paragraph_image
+
+
+def generate_random_batches(
+ values: List[Any], min_batch_size: int, max_batch_size: int
+) -> List[List[Any]]:
+ """Generate random batches of elements in values without replacement."""
+ shuffled_values = values.copy()
+ random.shuffle(shuffled_values)
+
+ start_index = 0
+ grouped_values_list = []
+ while start_index < len(shuffled_values):
+ num_values = random.randint(min_batch_size, max_batch_size)
+ grouped_values_list.append(
+ shuffled_values[start_index : start_index + num_values]
+ )
+ start_index += num_values
+
+ if sum([len(grp) for grp in grouped_values_list]) != len(values):
+ raise ValueError("Length of groups does not match length of values.")
+
+ return grouped_values_list
+
+
+def create_synthetic_iam_paragraphs() -> None:
+ load_and_print_info(IAMSyntheticParagraphs)