From 30e3ae483c846418b04ed48f014a4af2cf9a0771 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:03:11 +0200 Subject: Update transforms in datamodule/set --- text_recognizer/data/iam_lines.py | 75 +++++++-------------------------------- 1 file changed, 12 insertions(+), 63 deletions(-) (limited to 'text_recognizer/data/iam_lines.py') diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 7a063c1..efd1cde 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -5,7 +5,6 @@ dataset. """ import json from pathlib import Path -import random from typing import List, Sequence, Tuple import attr @@ -13,19 +12,17 @@ from loguru import logger as log import numpy as np from PIL import Image, ImageFile, ImageOps from torch import Tensor -import torchvision.transforms as T -from torchvision.transforms.functional import InterpolationMode -from text_recognizer.data import image_utils from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.base_dataset import ( BaseDataset, convert_strings_to_labels, split_dataset, ) -from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.iam_paragraphs import get_target_transform +from text_recognizer.data.mappings.emnist_mapping import EmnistMapping +from text_recognizer.data.utils import image_utils +from text_recognizer.data.transforms.load_transform import load_transform_from_file ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -42,9 +39,6 @@ MAX_WORD_PIECE_LENGTH = 72 class IAMLines(BaseDataModule): """IAM handwritten lines dataset.""" - word_pieces: bool = attr.ib(default=False) - augment: bool = attr.ib(default=True) - train_fraction: float = attr.ib(default=0.8) dims: Tuple[int, int, int] = attr.ib( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) @@ -94,10 +88,8 @@ class IAMLines(BaseDataModule): data_train = BaseDataset( x_train, y_train, - transform=get_transform(IMAGE_WIDTH, self.augment), - target_transform=get_target_transform( - self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH - ), + transform=self.transform, + target_transform=self.target_transform, ) self.data_train, self.data_val = split_dataset( @@ -118,10 +110,8 @@ class IAMLines(BaseDataModule): self.data_test = BaseDataset( x_test, y_test, - transform=get_transform(IMAGE_WIDTH), - target_transform=get_target_transform( - self.word_pieces, max_len=MAX_WORD_PIECE_LENGTH - ), + transform=self.test_transform, + target_transform=self.target_transform, ) if stage is None: @@ -147,6 +137,8 @@ class IAMLines(BaseDataModule): x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) + x = x[0] if isinstance(x, list) else x + xt = xt[0] if isinstance(xt, list) else xt data = ( "Train/val/test sizes: " f"{len(self.data_train)}, " @@ -217,51 +209,8 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li return crops, labels -def get_transform(image_width: int, augment: bool = False) -> T.Compose: - """Augment with brigthness, rotation, slant, translation, scale, and noise.""" - - def embed_crop( - crop: Image, augment: bool = augment, image_width: int = image_width - ) -> Image: - # Crop is PIL.Image of dtype="L" (so value range is [0, 255]) - image = Image.new("L", (image_width, IMAGE_HEIGHT)) - - # Resize crop. - crop_width, crop_height = crop.size - new_crop_height = IMAGE_HEIGHT - new_crop_width = int(new_crop_height * (crop_width / crop_height)) - - if augment: - # Add random stretching - new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) - new_crop_width = min(new_crop_width, image_width) - crop_resized = crop.resize( - (new_crop_width, new_crop_height), resample=Image.BILINEAR - ) - - # Embed in image - x = min(28, image_width - new_crop_width) - y = IMAGE_HEIGHT - new_crop_height - image.paste(crop_resized, (x, y)) - - return image - - transfroms_list = [T.Lambda(embed_crop)] - - if augment: - transfroms_list += [ - T.ColorJitter(brightness=(0.8, 1.6)), - T.RandomAffine( - degrees=1, - shear=(-30, 20), - interpolation=InterpolationMode.BILINEAR, - fill=0, - ), - ] - transfroms_list.append(T.ToTensor()) - return T.Compose(transfroms_list) - - def generate_iam_lines() -> None: """Displays Iam Lines dataset statistics.""" - load_and_print_info(IAMLines) + transform = load_transform_from_file("transform/iam_lines.yaml") + test_transform = load_transform_from_file("test_transform/iam_lines.yaml") + load_and_print_info(IAMLines(transform=transform, test_transform=test_transform)) -- cgit v1.2.3-70-g09d2