From 30e3ae483c846418b04ed48f014a4af2cf9a0771 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:03:11 +0200 Subject: Update transforms in datamodule/set --- text_recognizer/data/iam_paragraphs.py | 68 +++++++++------------------------- 1 file changed, 17 insertions(+), 51 deletions(-) (limited to 'text_recognizer/data/iam_paragraphs.py') diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 254c7f5..26674e0 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -8,7 +8,6 @@ from loguru import logger as log import numpy as np from PIL import Image, ImageOps import torchvision.transforms as T -from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -17,9 +16,9 @@ from text_recognizer.data.base_dataset import ( convert_strings_to_labels, split_dataset, ) -from text_recognizer.data.emnist_mapping import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.transforms import WordPiece +from text_recognizer.data.mappings.emnist_mapping import EmnistMapping +from text_recognizer.data.transforms.load_transform import load_transform_from_file PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_paragraphs" @@ -38,11 +37,6 @@ MAX_WORD_PIECE_LENGTH = 451 class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - word_pieces: bool = attr.ib(default=False) - augment: bool = attr.ib(default=True) - train_fraction: float = attr.ib(default=0.8) - resize: Optional[Tuple[int, int]] = attr.ib(default=None) - # Placeholders dims: Tuple[int, int, int] = attr.ib( init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) @@ -82,7 +76,7 @@ class IAMParagraphs(BaseDataModule): """Loads the data for training/testing.""" def _load_dataset( - split: str, augment: bool, resize: Optional[Tuple[int, int]] + split: str, transform: T.Compose, target_transform: T.Compose ) -> BaseDataset: crops, labels = _load_processed_crops_and_labels(split) data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] @@ -92,12 +86,7 @@ class IAMParagraphs(BaseDataModule): length=self.output_dims[0], ) return BaseDataset( - data, - targets, - transform=get_transform( - image_shape=self.dims[1:], augment=augment, resize=resize - ), - target_transform=get_target_transform(self.word_pieces), + data, targets, transform=transform, target_transform=target_transform, ) log.info(f"Loading IAM paragraph regions and lines for {stage}...") @@ -105,7 +94,9 @@ class IAMParagraphs(BaseDataModule): if stage == "fit" or stage is None: data_train = _load_dataset( - split="train", augment=self.augment, resize=self.resize + split="train", + transform=self.transform, + target_transform=self.target_transform, ) self.data_train, self.data_val = split_dataset( dataset=data_train, fraction=self.train_fraction, seed=SEED @@ -113,7 +104,9 @@ class IAMParagraphs(BaseDataModule): if stage == "test" or stage is None: self.data_test = _load_dataset( - split="test", augment=False, resize=self.resize + split="test", + transform=self.test_transform, + target_transform=self.target_transform, ) def __repr__(self) -> str: @@ -130,6 +123,8 @@ class IAMParagraphs(BaseDataModule): x, y = next(iter(self.train_dataloader())) xt, yt = next(iter(self.test_dataloader())) + x = x[0] if isinstance(x, list) else x + xt = xt[0] if isinstance(xt, list) else xt data = ( "Train/val/test sizes: " f"{len(self.data_train)}, " @@ -274,39 +269,6 @@ def _load_processed_crops_and_labels( return ordered_crops, ordered_labels -def get_transform( - image_shape: Tuple[int, int], augment: bool, resize: Optional[Tuple[int, int]] -) -> T.Compose: - """Get transformations for images.""" - if augment: - transforms_list = [ - T.RandomCrop( - size=image_shape, - padding=None, - pad_if_needed=True, - fill=0, - padding_mode="constant", - ), - T.ColorJitter(brightness=(0.8, 1.6)), - T.RandomAffine( - degrees=1, shear=(-10, 10), interpolation=InterpolationMode.BILINEAR, - ), - ] - else: - transforms_list = [T.CenterCrop(image_shape)] - if resize is not None: - transforms_list.append(T.Resize(resize, T.InterpolationMode.BILINEAR)) - transforms_list.append(T.ToTensor()) - return T.Compose(transforms_list) - - -def get_target_transform( - word_pieces: bool, max_len: int = MAX_WORD_PIECE_LENGTH -) -> Optional[T.Compose]: - """Transform emnist characters to word pieces.""" - return T.Compose([WordPiece(max_len=max_len)]) if word_pieces else None - - def _labels_filename(split: str) -> Path: """Return filename of processed labels.""" return PROCESSED_DATA_DIRNAME / split / "_labels.json" @@ -324,4 +286,8 @@ def _num_lines(label: str) -> int: def create_iam_paragraphs() -> None: """Loads and displays dataset statistics.""" - load_and_print_info(IAMParagraphs) + transform = load_transform_from_file("transform/paragraphs.yaml") + test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml") + load_and_print_info( + IAMParagraphs(transform=transform, test_transform=test_transform) + ) -- cgit v1.2.3-70-g09d2