summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:05:24 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-10 18:05:24 +0200
commit38f546f0b86fc0dc89863b00c5ee8c6685394ef2 (patch)
treedeaf5501c8e687e235add858ad92e068df7d3615 /text_recognizer/data
parent8291a87c64f9a5f18caec82201bea15579b49730 (diff)
Add custom transforms
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/transforms/__init__.py1
-rw-r--r--text_recognizer/data/transforms/barlow.py19
-rw-r--r--text_recognizer/data/transforms/embed_crop.py37
-rw-r--r--text_recognizer/data/transforms/load_transform.py47
4 files changed, 104 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/__init__.py b/text_recognizer/data/transforms/__init__.py
new file mode 100644
index 0000000..7d521a5
--- /dev/null
+++ b/text_recognizer/data/transforms/__init__.py
@@ -0,0 +1 @@
+"""Dataset transforms."""
diff --git a/text_recognizer/data/transforms/barlow.py b/text_recognizer/data/transforms/barlow.py
new file mode 100644
index 0000000..78683cb
--- /dev/null
+++ b/text_recognizer/data/transforms/barlow.py
@@ -0,0 +1,19 @@
+"""Augmentations for training Barlow Twins."""
+from omegaconf.dictconfig import DictConfig
+from torch import Tensor
+
+from text_recognizer.data.transforms.load_transform import load_transform
+
+
+class BarlowTransform:
+ """Applies two different transforms to input data."""
+
+ def __init__(self, prim: DictConfig, bis: DictConfig) -> None:
+ self.prim = load_transform(prim)
+ self.bis = load_transform(bis)
+
+ def __call__(self, data: Tensor) -> Tensor:
+ """Applies two different augmentation on the input."""
+ x_prim = self.prim(data)
+ x_bis = self.bis(data)
+ return x_prim, x_bis
diff --git a/text_recognizer/data/transforms/embed_crop.py b/text_recognizer/data/transforms/embed_crop.py
new file mode 100644
index 0000000..7421d0e
--- /dev/null
+++ b/text_recognizer/data/transforms/embed_crop.py
@@ -0,0 +1,37 @@
+"""Transforms for PyTorch datasets."""
+import random
+
+from PIL import Image
+
+
+class EmbedCrop:
+
+ IMAGE_HEIGHT = 56
+ IMAGE_WIDTH = 1024
+
+ def __init__(self, augment: bool) -> None:
+ self.augment = augment
+
+ def __call__(self, crop: Image) -> Image:
+ # Crop is PIL.Image of dtype="L" (so value range is [0, 255])
+ image = Image.new("L", (self.IMAGE_WIDTH, self.IMAGE_HEIGHT))
+
+ # Resize crop.
+ crop_width, crop_height = crop.size
+ new_crop_height = self.IMAGE_HEIGHT
+ new_crop_width = int(new_crop_height * (crop_width / crop_height))
+
+ if self.augment:
+ # Add random stretching
+ new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
+ new_crop_width = min(new_crop_width, self.IMAGE_WIDTH)
+ crop_resized = crop.resize(
+ (new_crop_width, new_crop_height), resample=Image.BILINEAR
+ )
+
+ # Embed in image
+ x = min(28, self.IMAGE_WIDTH - new_crop_width)
+ y = self.IMAGE_HEIGHT - new_crop_height
+ image.paste(crop_resized, (x, y))
+
+ return image
diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py
new file mode 100644
index 0000000..cf590c1
--- /dev/null
+++ b/text_recognizer/data/transforms/load_transform.py
@@ -0,0 +1,47 @@
+"""Load a config of transforms."""
+from pathlib import Path
+from typing import Callable
+
+from loguru import logger as log
+from omegaconf import OmegaConf
+from omegaconf import DictConfig
+from hydra.utils import instantiate
+import torchvision.transforms as T
+
+TRANSFORM_DIRNAME = (
+ Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
+)
+
+
+def _load_config(filepath: str) -> DictConfig:
+ log.debug(f"Loading transforms from config: {filepath}")
+ path = TRANSFORM_DIRNAME / Path(filepath)
+ with open(path) as f:
+ cfgs = OmegaConf.load(f)
+ return cfgs
+
+
+def _load_transform(transform: DictConfig) -> Callable:
+ """Loads a transform."""
+ if "ColorJitter" in transform._target_:
+ return T.ColorJitter(brightness=list(transform.brightness))
+ if transform.get("interpolation"):
+ transform.interpolation = getattr(
+ T.functional.InterpolationMode, transform.interpolation
+ )
+ return instantiate(transform, _recursive_=False)
+
+
+def load_transform_from_file(filepath: str) -> T.Compose:
+ """Loads transforms from a config."""
+ cfgs = _load_config(filepath)
+ transform = load_transform(cfgs)
+ return transform
+
+
+def load_transform(cfgs: DictConfig) -> T.Compose:
+ transforms = []
+ for cfg in cfgs.values():
+ transform = _load_transform(cfg)
+ transforms.append(transform)
+ return T.Compose(transforms)