From 38f546f0b86fc0dc89863b00c5ee8c6685394ef2 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 10 Oct 2021 18:05:24 +0200 Subject: Add custom transforms --- text_recognizer/data/transforms/__init__.py | 1 + text_recognizer/data/transforms/barlow.py | 19 +++++++++ text_recognizer/data/transforms/embed_crop.py | 37 ++++++++++++++++++ text_recognizer/data/transforms/load_transform.py | 47 +++++++++++++++++++++++ 4 files changed, 104 insertions(+) create mode 100644 text_recognizer/data/transforms/__init__.py create mode 100644 text_recognizer/data/transforms/barlow.py create mode 100644 text_recognizer/data/transforms/embed_crop.py create mode 100644 text_recognizer/data/transforms/load_transform.py (limited to 'text_recognizer/data/transforms') diff --git a/text_recognizer/data/transforms/__init__.py b/text_recognizer/data/transforms/__init__.py new file mode 100644 index 0000000..7d521a5 --- /dev/null +++ b/text_recognizer/data/transforms/__init__.py @@ -0,0 +1 @@ +"""Dataset transforms.""" diff --git a/text_recognizer/data/transforms/barlow.py b/text_recognizer/data/transforms/barlow.py new file mode 100644 index 0000000..78683cb --- /dev/null +++ b/text_recognizer/data/transforms/barlow.py @@ -0,0 +1,19 @@ +"""Augmentations for training Barlow Twins.""" +from omegaconf.dictconfig import DictConfig +from torch import Tensor + +from text_recognizer.data.transforms.load_transform import load_transform + + +class BarlowTransform: + """Applies two different transforms to input data.""" + + def __init__(self, prim: DictConfig, bis: DictConfig) -> None: + self.prim = load_transform(prim) + self.bis = load_transform(bis) + + def __call__(self, data: Tensor) -> Tensor: + """Applies two different augmentation on the input.""" + x_prim = self.prim(data) + x_bis = self.bis(data) + return x_prim, x_bis diff --git a/text_recognizer/data/transforms/embed_crop.py b/text_recognizer/data/transforms/embed_crop.py new file mode 100644 index 0000000..7421d0e --- /dev/null +++ b/text_recognizer/data/transforms/embed_crop.py @@ -0,0 +1,37 @@ +"""Transforms for PyTorch datasets.""" +import random + +from PIL import Image + + +class EmbedCrop: + + IMAGE_HEIGHT = 56 + IMAGE_WIDTH = 1024 + + def __init__(self, augment: bool) -> None: + self.augment = augment + + def __call__(self, crop: Image) -> Image: + # Crop is PIL.Image of dtype="L" (so value range is [0, 255]) + image = Image.new("L", (self.IMAGE_WIDTH, self.IMAGE_HEIGHT)) + + # Resize crop. + crop_width, crop_height = crop.size + new_crop_height = self.IMAGE_HEIGHT + new_crop_width = int(new_crop_height * (crop_width / crop_height)) + + if self.augment: + # Add random stretching + new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) + new_crop_width = min(new_crop_width, self.IMAGE_WIDTH) + crop_resized = crop.resize( + (new_crop_width, new_crop_height), resample=Image.BILINEAR + ) + + # Embed in image + x = min(28, self.IMAGE_WIDTH - new_crop_width) + y = self.IMAGE_HEIGHT - new_crop_height + image.paste(crop_resized, (x, y)) + + return image diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py new file mode 100644 index 0000000..cf590c1 --- /dev/null +++ b/text_recognizer/data/transforms/load_transform.py @@ -0,0 +1,47 @@ +"""Load a config of transforms.""" +from pathlib import Path +from typing import Callable + +from loguru import logger as log +from omegaconf import OmegaConf +from omegaconf import DictConfig +from hydra.utils import instantiate +import torchvision.transforms as T + +TRANSFORM_DIRNAME = ( + Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule" +) + + +def _load_config(filepath: str) -> DictConfig: + log.debug(f"Loading transforms from config: {filepath}") + path = TRANSFORM_DIRNAME / Path(filepath) + with open(path) as f: + cfgs = OmegaConf.load(f) + return cfgs + + +def _load_transform(transform: DictConfig) -> Callable: + """Loads a transform.""" + if "ColorJitter" in transform._target_: + return T.ColorJitter(brightness=list(transform.brightness)) + if transform.get("interpolation"): + transform.interpolation = getattr( + T.functional.InterpolationMode, transform.interpolation + ) + return instantiate(transform, _recursive_=False) + + +def load_transform_from_file(filepath: str) -> T.Compose: + """Loads transforms from a config.""" + cfgs = _load_config(filepath) + transform = load_transform(cfgs) + return transform + + +def load_transform(cfgs: DictConfig) -> T.Compose: + transforms = [] + for cfg in cfgs.values(): + transform = _load_transform(cfg) + transforms.append(transform) + return T.Compose(transforms) -- cgit v1.2.3-70-g09d2