From 9b8e14d89f0ef2508ed11f994f73af624155fe1d Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 27 Sep 2022 01:44:49 +0200 Subject: Update data modules --- text_recognizer/data/base_dataset.py | 16 +---- text_recognizer/data/emnist.py | 2 +- text_recognizer/data/emnist_lines.py | 8 +-- text_recognizer/data/iam.py | 2 +- text_recognizer/data/iam_extended_paragraphs.py | 15 +++-- text_recognizer/data/iam_lines.py | 7 +- text_recognizer/data/iam_paragraphs.py | 16 +++-- text_recognizer/data/iam_synthetic_paragraphs.py | 56 ++++++++------- text_recognizer/data/stems/__init__.py | 0 text_recognizer/data/stems/image.py | 18 +++++ text_recognizer/data/stems/line.py | 86 ++++++++++++++++++++++++ text_recognizer/data/stems/paragraph.py | 66 ++++++++++++++++++ 12 files changed, 235 insertions(+), 57 deletions(-) create mode 100644 text_recognizer/data/stems/__init__.py create mode 100644 text_recognizer/data/stems/image.py create mode 100644 text_recognizer/data/stems/line.py create mode 100644 text_recognizer/data/stems/paragraph.py (limited to 'text_recognizer/data') diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 4ceb818..b840bc8 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -5,8 +5,6 @@ import torch from torch import Tensor from torch.utils.data import Dataset -from text_recognizer.data.transforms.load_transform import load_transform_from_file - class BaseDataset(Dataset): r"""Base Dataset class that processes data and targets through optional transfroms. @@ -23,8 +21,8 @@ class BaseDataset(Dataset): self, data: Union[Sequence, Tensor], targets: Union[Sequence, Tensor], - transform: Union[Optional[Callable], str], - target_transform: Union[Optional[Callable], str], + transform: Callable, + target_transform: Callable, ) -> None: super().__init__() @@ -34,16 +32,6 @@ class BaseDataset(Dataset): self.target_transform = target_transform if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") - self.transform = self._load_transform(self.transform) - self.target_transform = self._load_transform(self.target_transform) - - @staticmethod - def _load_transform( - transform: Union[Optional[Callable], str] - ) -> Optional[Callable]: - if isinstance(transform, str): - return load_transform_from_file(transform) - return transform def __len__(self) -> int: """Return the length of the dataset.""" diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 9c5727f..143705e 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -15,7 +15,7 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print from text_recognizer.data.base_dataset import BaseDataset, split_dataset from text_recognizer.data.transforms.load_transform import load_transform_from_file from text_recognizer.data.utils.download_utils import download_dataset -from text_recognizer.metadata import emnist as metadata +import text_recognizer.metadata.emnist as metadata class EMNIST(BaseDataModule): diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 63c9f22..88aac0d 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -13,9 +13,9 @@ from text_recognizer.data.base_data_module import BaseDataModule, load_and_print from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.emnist import EMNIST from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.data.transforms.load_transform import load_transform_from_file +from text_recognizer.data.stems.line import LineStem from text_recognizer.data.utils.sentence_generator import SentenceGenerator -from text_recognizer.metadata import emnist_lines as metadata +import text_recognizer.metadata.emnist_lines as metadata class EMNISTLines(BaseDataModule): @@ -250,6 +250,6 @@ def _create_dataset_of_images( def generate_emnist_lines() -> None: """Generates a synthetic handwritten dataset and displays info.""" - transform = load_transform_from_file("transform/emnist_lines.yaml") - test_transform = load_transform_from_file("test_transform/default.yaml") + transform = LineStem(augment=False) + test_transform = LineStem(augment=False) load_and_print_info(EMNISTLines(transform=transform, test_transform=test_transform)) diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index c20b50b..2ce1e9c 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -15,7 +15,7 @@ from loguru import logger as log from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.utils.download_utils import download_dataset -from text_recognizer.metadata import iam as metadata +import text_recognizer.metadata.iam as metadata class IAM(BaseDataModule): diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 3ec8221..658626c 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -6,8 +6,10 @@ from torch.utils.data import ConcatDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs +from text_recognizer.data.transforms.pad import Pad from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.data.transforms.load_transform import load_transform_from_file +from text_recognizer.data.stems.paragraph import ParagraphStem +import text_recognizer.metadata.iam_paragraphs as metadata class IAMExtendedParagraphs(BaseDataModule): @@ -104,8 +106,13 @@ class IAMExtendedParagraphs(BaseDataModule): def show_dataset_info() -> None: """Displays Iam extended dataset information.""" - transform = load_transform_from_file("transform/paragraphs.yaml") - test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml") + transform = ParagraphStem(augment=False) + test_transform = ParagraphStem(augment=False) + target_transform = Pad(metadata.MAX_LABEL_LENGTH, 3) load_and_print_info( - IAMExtendedParagraphs(transform=transform, test_transform=test_transform) + IAMExtendedParagraphs( + transform=transform, + test_transform=test_transform, + target_transform=target_transform, + ) ) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index 3bb189c..e60d1ba 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -21,8 +21,9 @@ from text_recognizer.data.base_dataset import ( from text_recognizer.data.iam import IAM from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.transforms.load_transform import load_transform_from_file +from text_recognizer.data.stems.line import IamLinesStem from text_recognizer.data.utils import image_utils -from text_recognizer.metadata import iam_lines as metadata +import text_recognizer.metadata.iam_lines as metadata ImageFile.LOAD_TRUNCATED_IMAGES = True @@ -227,6 +228,6 @@ def load_line_crops_and_labels(split: str, data_dirname: Path) -> Tuple[List, Li def generate_iam_lines() -> None: """Displays Iam Lines dataset statistics.""" - transform = load_transform_from_file("transform/lines.yaml") - test_transform = load_transform_from_file("test_transform/lines.yaml") + transform = IamLinesStem() + test_transform = IamLinesStem() load_and_print_info(IAMLines(transform=transform, test_transform=test_transform)) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index eec1b1f..fe1f15c 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -16,9 +16,10 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.iam import IAM +from text_recognizer.data.transforms.pad import Pad from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.data.transforms.load_transform import load_transform_from_file -from text_recognizer.metadata import iam_paragraphs as metadata +from text_recognizer.data.stems.paragraph import ParagraphStem +import text_recognizer.metadata.iam_paragraphs as metadata class IAMParagraphs(BaseDataModule): @@ -294,8 +295,13 @@ def _num_lines(label: str) -> int: def create_iam_paragraphs() -> None: """Loads and displays dataset statistics.""" - transform = load_transform_from_file("transform/paragraphs.yaml") - test_transform = load_transform_from_file("test_transform/paragraphs_test.yaml") + transform = ParagraphStem() + test_transform = ParagraphStem() + target_transform = Pad(metadata.MAX_LABEL_LENGTH, 3) load_and_print_info( - IAMParagraphs(transform=transform, test_transform=test_transform) + IAMParagraphs( + transform=transform, + test_transform=test_transform, + target_transform=target_transform, + ) ) diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 52ed398..91fda4a 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -9,25 +9,20 @@ from PIL import Image from text_recognizer.data.base_data_module import load_and_print_info from text_recognizer.data.base_dataset import BaseDataset, convert_strings_to_labels from text_recognizer.data.iam import IAM -from text_recognizer.data.iam_lines import ( - line_crops_and_labels, - load_line_crops_and_labels, - save_images_and_labels, -) from text_recognizer.data.iam_paragraphs import ( - IMAGE_SCALE_FACTOR, - NEW_LINE_TOKEN, IAMParagraphs, get_dataset_properties, resize_image, ) -from text_recognizer.data.mappings import EmnistMapping -from text_recognizer.data.transforms.load_transform import load_transform_from_file -from text_recognizer.metadata import shared as metadata - -PROCESSED_DATA_DIRNAME = ( - metadata.DATA_DIRNAME / "processed" / "iam_synthetic_paragraphs" +from text_recognizer.data.iam_lines import ( + line_crops_and_labels, + load_line_crops_and_labels, + save_images_and_labels, ) +from text_recognizer.data.mappings import EmnistMapping +from text_recognizer.data.stems.paragraph import ParagraphStem +from text_recognizer.data.transforms.pad import Pad +import text_recognizer.metadata.iam_synthetic_paragraphs as metadata class IAMSyntheticParagraphs(IAMParagraphs): @@ -57,26 +52,32 @@ class IAMSyntheticParagraphs(IAMParagraphs): def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" - if PROCESSED_DATA_DIRNAME.exists(): + if metadata.PROCESSED_DATA_DIRNAME.exists(): return log.info("Preparing IAM lines for synthetic paragraphs dataset.") log.info("Cropping IAM line regions and loading labels.") - iam = IAM(mapping=EmnistMapping(extra_symbols=(NEW_LINE_TOKEN,))) + iam = IAM(mapping=EmnistMapping(extra_symbols=(metadata.NEW_LINE_TOKEN,))) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") - crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train] - crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test] + crops_train = [ + resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops_train + ] + crops_test = [ + resize_image(crop, metadata.IMAGE_SCALE_FACTOR) for crop in crops_test + ] - log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}") + log.info(f"Saving images and labels at {metadata.PROCESSED_DATA_DIRNAME}") save_images_and_labels( - crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME + crops_train, labels_train, "train", metadata.PROCESSED_DATA_DIRNAME + ) + save_images_and_labels( + crops_test, labels_test, "test", metadata.PROCESSED_DATA_DIRNAME ) - save_images_and_labels(crops_test, labels_test, "test", PROCESSED_DATA_DIRNAME) def setup(self, stage: str = None) -> None: """Loading synthetic dataset.""" @@ -85,7 +86,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): if stage == "fit" or stage is None: line_crops, line_labels = load_line_crops_and_labels( - "train", PROCESSED_DATA_DIRNAME + "train", metadata.PROCESSED_DATA_DIRNAME ) data, paragraphs_labels = generate_synthetic_paragraphs( line_crops=line_crops, line_labels=line_labels @@ -157,7 +158,7 @@ def generate_synthetic_paragraphs( paragraphs_crops, paragraphs_labels = [], [] for paragraph_indices in batched_indices_list: - paragraph_label = NEW_LINE_TOKEN.join( + paragraph_label = metadata.NEW_LINE_TOKEN.join( [line_labels[i] for i in paragraph_indices] ) if len(paragraph_label) > paragraphs_properties["label_length"]["max"]: @@ -236,8 +237,13 @@ def generate_random_batches( def create_synthetic_iam_paragraphs() -> None: """Creates and prints IAM Synthetic Paragraphs dataset.""" - transform = load_transform_from_file("transform/paragraphs.yaml") - test_transform = load_transform_from_file("test_transform/paragraphs.yaml") + transform = ParagraphStem() + test_transform = ParagraphStem() + target_transform = Pad(metadata.MAX_LABEL_LENGTH, 3) load_and_print_info( - IAMSyntheticParagraphs(transform=transform, test_transform=test_transform) + IAMSyntheticParagraphs( + transform=transform, + test_transform=test_transform, + target_transform=target_transform, + ) ) diff --git a/text_recognizer/data/stems/__init__.py b/text_recognizer/data/stems/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/text_recognizer/data/stems/image.py b/text_recognizer/data/stems/image.py new file mode 100644 index 0000000..f04b3a0 --- /dev/null +++ b/text_recognizer/data/stems/image.py @@ -0,0 +1,18 @@ +from PIL import Image +import torch +from torch import Tensor +import torchvision.transforms as T + + +class ImageStem: + def __init__(self) -> None: + self.pil_transform = T.Compose([]) + self.pil_to_tensor = T.ToTensor() + self.torch_transform = torch.nn.Sequential() + + def __call__(self, img: Image) -> Tensor: + img = self.pil_transform(img) + img = self.pil_to_tensor(img) + with torch.no_grad(): + img = self.torch_transform(img) + return img diff --git a/text_recognizer/data/stems/line.py b/text_recognizer/data/stems/line.py new file mode 100644 index 0000000..2fe1a2c --- /dev/null +++ b/text_recognizer/data/stems/line.py @@ -0,0 +1,86 @@ +import random + +from PIL import Image +import torchvision.transforms as T + +import text_recognizer.metadata.iam_lines as metadata +from text_recognizer.data.stems.image import ImageStem + + +class LineStem(ImageStem): + """A stem for handling images containing a line of text.""" + + def __init__( + self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None + ): + super().__init__() + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": (0.5, 1)} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 3, + "translate": (0, 0.05), + "scale": (0.4, 1.1), + "shear": (-40, 50), + "interpolation": T.InterpolationMode.BILINEAR, + "fill": 0, + } + + if augment: + self.pil_transforms = T.Compose( + [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomAffine(**random_affine_kwargs), + ] + ) + + +class IamLinesStem(ImageStem): + """A stem for handling images containing lines of text from the IAMLines dataset.""" + + def __init__( + self, augment=False, color_jitter_kwargs=None, random_affine_kwargs=None + ): + super().__init__() + + def embed_crop(crop, augment=augment): + # crop is PIL.image of dtype="L" (so values range from 0 -> 255) + image = Image.new("L", (metadata.IMAGE_WIDTH, metadata.IMAGE_HEIGHT)) + + # Resize crop + crop_width, crop_height = crop.size + new_crop_height = metadata.IMAGE_HEIGHT + new_crop_width = int(new_crop_height * (crop_width / crop_height)) + if augment: + # Add random stretching + new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) + new_crop_width = min(new_crop_width, metadata.IMAGE_WIDTH) + crop_resized = crop.resize( + (new_crop_width, new_crop_height), resample=Image.BILINEAR + ) + + # Embed in the image + x = min(metadata.CHAR_WIDTH, metadata.IMAGE_WIDTH - new_crop_width) + y = metadata.IMAGE_HEIGHT - new_crop_height + + image.paste(crop_resized, (x, y)) + + return image + + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": (0.8, 1.6)} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 1, + "shear": (-30, 20), + "interpolation": T.InterpolationMode.BILINEAR, + "fill": 0, + } + + pil_transform_list = [T.Lambda(embed_crop)] + if augment: + pil_transform_list += [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomAffine(**random_affine_kwargs), + ] + self.pil_transform = T.Compose(pil_transform_list) diff --git a/text_recognizer/data/stems/paragraph.py b/text_recognizer/data/stems/paragraph.py new file mode 100644 index 0000000..39e1e59 --- /dev/null +++ b/text_recognizer/data/stems/paragraph.py @@ -0,0 +1,66 @@ +"""Iam paragraph stem class.""" +import torchvision.transforms as T + +import text_recognizer.metadata.iam_paragraphs as metadata +from text_recognizer.data.stems.image import ImageStem + + +IMAGE_HEIGHT, IMAGE_WIDTH = metadata.IMAGE_HEIGHT, metadata.IMAGE_WIDTH +IMAGE_SHAPE = metadata.IMAGE_SHAPE + +MAX_LABEL_LENGTH = metadata.MAX_LABEL_LENGTH + + +class ParagraphStem(ImageStem): + """A stem for handling images that contain a paragraph of text.""" + + def __init__( + self, + augment=False, + color_jitter_kwargs=None, + random_affine_kwargs=None, + random_perspective_kwargs=None, + gaussian_blur_kwargs=None, + sharpness_kwargs=None, + ): + super().__init__() + + if not augment: + self.pil_transform = T.Compose([T.CenterCrop(IMAGE_SHAPE)]) + else: + if color_jitter_kwargs is None: + color_jitter_kwargs = {"brightness": 0.4, "contrast": 0.4} + if random_affine_kwargs is None: + random_affine_kwargs = { + "degrees": 3, + "shear": 6, + "scale": (0.95, 1), + "interpolation": T.InterpolationMode.BILINEAR, + } + if random_perspective_kwargs is None: + random_perspective_kwargs = { + "distortion_scale": 0.2, + "p": 0.5, + "interpolation": T.InterpolationMode.BILINEAR, + } + if gaussian_blur_kwargs is None: + gaussian_blur_kwargs = {"kernel_size": (3, 3), "sigma": (0.1, 1.0)} + if sharpness_kwargs is None: + sharpness_kwargs = {"sharpness_factor": 2, "p": 0.5} + + self.pil_transform = T.Compose( + [ + T.ColorJitter(**color_jitter_kwargs), + T.RandomCrop( + size=IMAGE_SHAPE, + padding=None, + pad_if_needed=True, + fill=0, + padding_mode="constant", + ), + T.RandomAffine(**random_affine_kwargs), + T.RandomPerspective(**random_perspective_kwargs), + T.GaussianBlur(**gaussian_blur_kwargs), + T.RandomAdjustSharpness(**sharpness_kwargs), + ] + ) -- cgit v1.2.3-70-g09d2