diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-03 21:59:07 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-03 21:59:07 +0200 |
commit | 07f2cc3665a1a60efe8ed8073cad6ac4f344b2c2 (patch) | |
tree | d24ae8e3b9b39bfcfb3b850b30cb966eb3b064a7 /text_recognizer/data | |
parent | 3196144ec99e803cef218295ddea592748931c57 (diff) |
Add IAM paragraphs dataset
Diffstat (limited to 'text_recognizer/data')
-rw-r--r-- | text_recognizer/data/emnist.py | 10 | ||||
-rw-r--r-- | text_recognizer/data/iam_lines.py | 4 | ||||
-rw-r--r-- | text_recognizer/data/iam_paragraphs.py | 310 | ||||
-rw-r--r-- | text_recognizer/data/iam_preprocessor.py | 6 | ||||
-rw-r--r-- | text_recognizer/data/make_wordpieces.py | 2 | ||||
-rw-r--r-- | text_recognizer/data/transforms.py | 11 |
6 files changed, 330 insertions, 13 deletions
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 3e10b5f..eda490a 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -1,6 +1,6 @@ """EMNIST dataset: downloads it from FSDL aws url if not present.""" from pathlib import Path -from typing import Dict, List, Sequence, Tuple +from typing import Dict, List, Optional, Sequence, Tuple import json import os import shutil @@ -52,7 +52,7 @@ class EMNIST(BaseDataModule): self.data_val = None self.data_test = None self.transform = transforms.Compose([transforms.ToTensor()]) - self.dims = (1, * self.input_shape) + self.dims = (1, *self.input_shape) self.output_dims = (1,) def prepare_data(self) -> None: @@ -95,13 +95,17 @@ class EMNIST(BaseDataModule): return basic + data -def emnist_mapping() -> Tuple[List, Dict[str, int], List[int]]: +def emnist_mapping( + extra_symbols: Optional[List[str]], +) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" if not ESSENTIALS_FILENAME.exists(): download_and_process_emnist() with ESSENTIALS_FILENAME.open() as f: essentials = json.load(f) mapping = list(essentials["characters"]) + if extra_symbols is not None: + mapping += extra_symbols inverse_mapping = {v: k for k, v in enumerate(mapping)} input_shape = essentials["input_shape"] return mapping, inverse_mapping, input_shape diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 391075a..78bc8e1 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -211,7 +211,9 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li def get_transform(image_width: int, augment: bool = False) -> transforms.Compose: """Augment with brigthness, sligth rotation, slant, translation, scale, and Gaussian noise.""" - def embed_crop(crop: Image, augment: bool = augment, image_width: int = image_width) -> Image: + def embed_crop( + crop: Image, augment: bool = augment, image_width: int = image_width + ) -> Image: # Crop is PIL.Image of dtype="L" (so value range is [0, 255]) image = Image.new("L", (image_width, IMAGE_HEIGHT)) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py new file mode 100644 index 0000000..402a8d4 --- /dev/null +++ b/text_recognizer/data/iam_paragraphs.py @@ -0,0 +1,310 @@ +"""IAM Paragraphs Dataset class.""" +import json +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple + +from loguru import logger +import numpy as np +from PIL import Image, ImageFile, ImageOps +import torch +import torchvision.transforms as transforms +from torchvision.transforms.functional import InterpolationMode +from tqdm import tqdm + +from text_recognizer.data.base_dataset import ( + BaseDataset, + convert_strings_to_labels, + split_dataset, +) +from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info +from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.iam import IAM + + +PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs" + +NEW_LINE_TOKEN = "\n" + +SEED = 4711 +IMAGE_SCALE_FACTOR = 2 +IMAGE_HEIGHT = 1152 // IMAGE_SCALE_FACTOR +IMAGE_WIDTH = 1280 // IMAGE_SCALE_FACTOR +MAX_LABEL_LENGTH = 682 + + +class IAMParagraphs(BaseDataModule): + """IAM handwriting database paragraphs.""" + + def __init__( + self, + batch_size: int = 128, + num_workers: int = 0, + train_fraction: float = 0.8, + augment: bool = True, + ) -> None: + super().__init__(batch_size, num_workers) + # TODO: pass in transform and target transform + # TODO: pass in mapping + self.augment = augment + self.mapping, self.inverse_mapping, _ = emnist_mapping( + extra_symbols=[NEW_LINE_TOKEN] + ) + self.train_fraction = train_fraction + + self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) + self.output_dims = (MAX_LABEL_LENGTH, 1) + self.data_train: BaseDataset = None + self.data_val: BaseDataset = None + self.data_test: BaseDataset = None + + def prepare_data(self) -> None: + """Create data for training/testing.""" + if PROCESSED_DATA_DIRNAME.exists(): + return + + logger.info( + "Cropping IAM paragraph regions and saving them along with labels..." + ) + + iam = IAM() + iam.prepare_data() + + properties = {} + for split in ["train", "test"]: + crops, labels = _get_paragraph_crops_and_labels(iam=iam, split=split) + _save_crops_and_labels(crops=crops, labels=labels, split=split) + + properties.update( + { + id_: { + "crop_shape": crops[id_].size[::-1], + "label_length": len(label), + "num_lines": _num_lines(label), + } + for id_, label in labels.items() + } + ) + + with (PROCESSED_DATA_DIRNAME / "_properties.json").open("w") as f: + json.dump(properties, f, indent=4) + + def setup(self, stage: str = None) -> None: + """Loads the data for training/testing.""" + + def _load_dataset(split: str, augment: bool) -> BaseDataset: + crops, labels = _load_processed_crops_and_labels(split) + data = [_resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] + targets = convert_strings_to_labels( + strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0] + ) + return BaseDataset( + data, + targets, + transform=_get_transform(image_shape=self.dims[1:], augment=augment), + ) + + logger.info(f"Loading IAM paragraph regions and lines for {stage}...") + _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims) + + if stage == "fit" or stage is None: + data_train = _load_dataset(split="train", augment=self.augment) + self.data_train, self.data_val = split_dataset( + dataset=data_train, fraction=self.train_fraction, seed=SEED + ) + + if stage == "test" or stage is None: + self.data_test = _load_dataset(split="test", augment=False) + + def __repr__(self) -> str: + """Return information about the dataset.""" + basic = ( + "IAM Paragraphs Dataset\n" + f"Num classes: {len(self.mapping)}\n" + f"Input dims: {self.dims}\n" + f"Output dims: {self.output_dims}\n" + ) + + if not any([self.data_train, self.data_val, self.data_test]): + return basic + + x, y = next(iter(self.train_dataloader())) + xt, yt = next(iter(self.test_dataloader())) + data = ( + "Train/val/test sizes: " + f"{len(self.data_train)}, " + f"{len(self.data_val)}, " + f"{len(self.data_test)}\n" + f"Train Batch x stats: {(x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())}\n" + f"Train Batch y stats: {(y.shape, y.dtype, y.min(), y.max())}\n" + f"Test Batch x stats: {(xt.shape, xt.dtype, xt.min(), xt.mean(), xt.std(), xt.max())}\n" + f"Test Batch y stats: {(yt.shape, yt.dtype, yt.min(), yt.max())}\n" + ) + return basic + data + + +def _get_dataset_properties() -> Dict: + """Return properties describing the overall dataset.""" + with (PROCESSED_DATA_DIRNAME / "_properties.json").open("r") as f: + properties = json.load(f) + + def _get_property_values(key: str) -> List: + return [value[key] for value in properties.values()] + + crop_shapes = np.array(_get_property_values("crop_shape")) + aspect_ratio = crop_shapes[:, 1] / crop_shapes[:, 0] + return { + "label_length": { + "min": min(_get_property_values("label_length")), + "max": max(_get_property_values("label_length")), + }, + "num_lines": { + "min": min(_get_property_values("num_lines")), + "max": max(_get_property_values("num_lines")), + }, + "crop_shape": {"min": crop_shapes.min(axis=0), "max": crop_shapes.max(axis=0),}, + "aspect_ratio": { + "min": aspect_ratio.min(axis=0), + "max": aspect_ratio.max(axis=0), + }, + } + + +def _validate_data_dims( + input_dims: Optional[Tuple[int, ...]], output_dims: Optional[Tuple[int, ...]] +) -> None: + """Validates input and output dimensions against the properties of the dataset.""" + properties = _get_dataset_properties() + + max_image_shape = properties["crop_shape"]["max"] / IMAGE_SCALE_FACTOR + if ( + input_dims is not None + and input_dims[1] < max_image_shape[0] + and input_dims[2] < max_image_shape[1] + ): + raise ValueError(f"{input_dims} less than {max_image_shape}") + + if ( + output_dims is not None + and output_dims[0] < properties["label_length"]["max"] + 2 + ): + raise ValueError( + f"{output_dims} less than {properties['label_length']['max'] + 2}" + ) + + +def _resize_image(image: Image.Image, scale_factor: int) -> Image.Image: + """Resize image by scale factor.""" + if scale_factor == 1: + return image + return image.resize( + (image.width // scale_factor, image.height // scale_factor), + resample=Image.BILINEAR, + ) + + +def _get_paragraph_crops_and_labels( + iam: IAM, split: str +) -> Tuple[Dict[str, Image.Image], Dict[str, str]]: + """Load IAM paragraph crops and labels for a given set.""" + crops = {} + labels = {} + for form_filename in tqdm( + iam.form_filenames, desc=f"Processing {split} paragraphs" + ): + id_ = form_filename.stem + if not iam.split_by_id[id_] == split: + continue + image = Image.open(form_filename) + image = ImageOps.grayscale(image) + image = ImageOps.invert(image) + + line_regions = iam.line_regions_by_id[id_] + parameter_box = [ + min([region["x1"] for region in line_regions]), + min([region["y1"] for region in line_regions]), + max([region["x2"] for region in line_regions]), + max([region["y2"] for region in line_regions]), + ] + lines = iam.line_strings_by_id[id_] + + crops[id_] = image.crop(parameter_box) + labels[id_] = NEW_LINE_TOKEN.join(lines) + + if len(crops) != len(labels): + raise ValueError(f"Crops ({len(crops)}) does not match labels ({len(labels)})") + + return crops, labels + + +def _save_crops_and_labels( + crops: Dict[str, Image.Image], labels: Dict[str, str], split: str +) -> None: + """Save crops, labels, and shapes of crops of a split.""" + (PROCESSED_DATA_DIRNAME / split).mkdir(parents=True, exist_ok=True) + + with _labels_filename(split).open("w") as f: + json.dump(labels, f, indent=4) + + for id_, crop in crops.items(): + crop.save(_crop_filename(id_, split)) + + +def _load_processed_crops_and_labels( + split: str, +) -> Tuple[Sequence[Image.Image], Sequence[str]]: + """Load processed crops and labels for given split.""" + with _labels_filename(split).open("r") as f: + labels = json.load(f) + + sorted_ids = sorted(labels.keys()) + ordered_crops = [ + Image.open(_crop_filename(id_, split)).convert("L") for id_ in sorted_ids + ] + ordered_labels = [labels[id_] for id_ in sorted_ids] + + if len(ordered_crops) != len(ordered_labels): + raise ValueError( + f"Crops ({len(ordered_crops)}) does not match labels ({len(ordered_labels)})" + ) + return ordered_crops, ordered_labels + + +def _get_transform(image_shape: Tuple[int, int], augment: bool) -> transforms.Compose: + """Get transformations for images.""" + if augment: + transforms_list = [ + transforms.RandomCrop( + size=image_shape, + padding=None, + pad_if_needed=True, + fill=0, + padding_mode="constant", + ), + transforms.ColorJitter(brightness=(0.8, 1.6)), + transforms.RandomAffine( + degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, + ), + ] + else: + transforms_list = [transforms.CenterCrop(image_shape)] + transforms_list.append(transforms.ToTensor()) + return transforms.Compose(transforms_list) + + +def _labels_filename(split: str) -> Path: + """Return filename of processed labels.""" + return PROCESSED_DATA_DIRNAME / split / "_labels.json" + + +def _crop_filename(id: str, split: str) -> Path: + """Return filename of processed crop.""" + return PROCESSED_DATA_DIRNAME / split / f"{id}.png" + + +def _num_lines(label: str) -> int: + """Return the number of lines of text in label.""" + return label.count("\n") + 1 + + +def create_iam_paragraphs() -> None: + load_and_print_info(IAMParagraphs) diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index a47aeed..3844419 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -166,7 +166,11 @@ def cli( """CLI for extracting text data from the iam dataset.""" if data_dir is None: data_dir = ( - Path(__file__).resolve().parents[2] / "data" / "downloaded" / "iam" / "iamdb" + Path(__file__).resolve().parents[2] + / "data" + / "downloaded" + / "iam" + / "iamdb" ) logger.debug(f"Using data dir: {data_dir}") if not data_dir.exists(): diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py index e062c4c..ef9eb1b 100644 --- a/text_recognizer/data/make_wordpieces.py +++ b/text_recognizer/data/make_wordpieces.py @@ -99,7 +99,7 @@ def cli( """CLI for training the sentence piece model.""" if data_dir is None: data_dir = ( - Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" + Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" ) logger.debug(f"Using data dir: {data_dir}") if not data_dir.exists(): diff --git a/text_recognizer/data/transforms.py b/text_recognizer/data/transforms.py index 2291eec..616e236 100644 --- a/text_recognizer/data/transforms.py +++ b/text_recognizer/data/transforms.py @@ -9,7 +9,8 @@ from torch import Tensor from text_recognizer.datasets.iam_preprocessor import Preprocessor from text_recognizer.data.emnist import emnist_mapping - + + class ToLower: """Converts target to lower case.""" @@ -23,15 +24,11 @@ class ToCharcters: """Converts integers to characters.""" def __init__(self) -> None: - self.mapping, _, _ = emnist_mapping() + self.mapping, _, _ = emnist_mapping() def __call__(self, y: Tensor) -> str: """Converts a Tensor to a str.""" - return ( - "".join([self.mapping(int(i)) for i in y]) - .strip("<p>") - .replace(" ", "▁") - ) + return "".join([self.mapping(int(i)) for i in y]).strip("<p>").replace(" ", "▁") class WordPieces: |