diff options
-rw-r--r-- | text_recognizer/data/transforms/embed_crop.py | 37 | ||||
-rw-r--r-- | text_recognizer/data/transforms/load_transform.py | 46 |
2 files changed, 0 insertions, 83 deletions
diff --git a/text_recognizer/data/transforms/embed_crop.py b/text_recognizer/data/transforms/embed_crop.py deleted file mode 100644 index 7421d0e..0000000 --- a/text_recognizer/data/transforms/embed_crop.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Transforms for PyTorch datasets.""" -import random - -from PIL import Image - - -class EmbedCrop: - - IMAGE_HEIGHT = 56 - IMAGE_WIDTH = 1024 - - def __init__(self, augment: bool) -> None: - self.augment = augment - - def __call__(self, crop: Image) -> Image: - # Crop is PIL.Image of dtype="L" (so value range is [0, 255]) - image = Image.new("L", (self.IMAGE_WIDTH, self.IMAGE_HEIGHT)) - - # Resize crop. - crop_width, crop_height = crop.size - new_crop_height = self.IMAGE_HEIGHT - new_crop_width = int(new_crop_height * (crop_width / crop_height)) - - if self.augment: - # Add random stretching - new_crop_width = int(new_crop_width * random.uniform(0.9, 1.1)) - new_crop_width = min(new_crop_width, self.IMAGE_WIDTH) - crop_resized = crop.resize( - (new_crop_width, new_crop_height), resample=Image.BILINEAR - ) - - # Embed in image - x = min(28, self.IMAGE_WIDTH - new_crop_width) - y = self.IMAGE_HEIGHT - new_crop_height - image.paste(crop_resized, (x, y)) - - return image diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py deleted file mode 100644 index e8c57bc..0000000 --- a/text_recognizer/data/transforms/load_transform.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Load a config of transforms.""" -from pathlib import Path -from typing import Callable - -import torchvision.transforms as T -from hydra.utils import instantiate -from loguru import logger as log -from omegaconf import DictConfig, OmegaConf - -TRANSFORM_DIRNAME = ( - Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule" -) - - -def _load_config(filepath: str) -> DictConfig: - log.debug(f"Loading transforms from config: {filepath}") - path = TRANSFORM_DIRNAME / Path(filepath) - with open(path) as f: - cfgs = OmegaConf.load(f) - return cfgs - - -def _load_transform(transform: DictConfig) -> Callable: - """Loads a transform.""" - if "ColorJitter" in transform._target_: - return T.ColorJitter(brightness=list(transform.brightness)) - if transform.get("interpolation"): - transform.interpolation = getattr( - T.functional.InterpolationMode, transform.interpolation - ) - return instantiate(transform, _recursive_=False) - - -def load_transform_from_file(filepath: str) -> T.Compose: - """Loads transforms from a config.""" - cfgs = _load_config(filepath) - transform = load_transform(cfgs) - return transform - - -def load_transform(cfgs: DictConfig) -> T.Compose: - transforms = [] - for cfg in cfgs.values(): - transform = _load_transform(cfg) - transforms.append(transform) - return T.Compose(transforms) |