summaryrefslogtreecommitdiff
path: root/text_recognizer/data
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data')
-rw-r--r--text_recognizer/data/transforms/embed_crop.py37
-rw-r--r--text_recognizer/data/transforms/load_transform.py46
2 files changed, 0 insertions, 83 deletions
diff --git a/text_recognizer/data/transforms/embed_crop.py b/text_recognizer/data/transforms/embed_crop.py
deleted file mode 100644
index 7421d0e..0000000
--- a/text_recognizer/data/transforms/embed_crop.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Transforms for PyTorch datasets."""
-import random
-
-from PIL import Image
-
-
-class EmbedCrop:
-
- IMAGE_HEIGHT = 56
- IMAGE_WIDTH = 1024
-
- def __init__(self, augment: bool) -> None:
- self.augment = augment
-
- def __call__(self, crop: Image) -> Image:
- # Crop is PIL.Image of dtype="L" (so value range is [0, 255])
- image = Image.new("L", (self.IMAGE_WIDTH, self.IMAGE_HEIGHT))
-
- # Resize crop.
- crop_width, crop_height = crop.size
- new_crop_height = self.IMAGE_HEIGHT
- new_crop_width = int(new_crop_height * (crop_width / crop_height))
-
- if self.augment:
- # Add random stretching
- new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1))
- new_crop_width = min(new_crop_width, self.IMAGE_WIDTH)
- crop_resized = crop.resize(
- (new_crop_width, new_crop_height), resample=Image.BILINEAR
- )
-
- # Embed in image
- x = min(28, self.IMAGE_WIDTH - new_crop_width)
- y = self.IMAGE_HEIGHT - new_crop_height
- image.paste(crop_resized, (x, y))
-
- return image
diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py
deleted file mode 100644
index e8c57bc..0000000
--- a/text_recognizer/data/transforms/load_transform.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""Load a config of transforms."""
-from pathlib import Path
-from typing import Callable
-
-import torchvision.transforms as T
-from hydra.utils import instantiate
-from loguru import logger as log
-from omegaconf import DictConfig, OmegaConf
-
-TRANSFORM_DIRNAME = (
- Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
-)
-
-
-def _load_config(filepath: str) -> DictConfig:
- log.debug(f"Loading transforms from config: {filepath}")
- path = TRANSFORM_DIRNAME / Path(filepath)
- with open(path) as f:
- cfgs = OmegaConf.load(f)
- return cfgs
-
-
-def _load_transform(transform: DictConfig) -> Callable:
- """Loads a transform."""
- if "ColorJitter" in transform._target_:
- return T.ColorJitter(brightness=list(transform.brightness))
- if transform.get("interpolation"):
- transform.interpolation = getattr(
- T.functional.InterpolationMode, transform.interpolation
- )
- return instantiate(transform, _recursive_=False)
-
-
-def load_transform_from_file(filepath: str) -> T.Compose:
- """Loads transforms from a config."""
- cfgs = _load_config(filepath)
- transform = load_transform(cfgs)
- return transform
-
-
-def load_transform(cfgs: DictConfig) -> T.Compose:
- transforms = []
- for cfg in cfgs.values():
- transform = _load_transform(cfg)
- transforms.append(transform)
- return T.Compose(transforms)