summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/load_transform.py
blob: e8c57bc24b86bd050c08566340500fc57a31dd46 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
"""Load a config of transforms."""
from pathlib import Path
from typing import Callable

import torchvision.transforms as T
from hydra.utils import instantiate
from loguru import logger as log
from omegaconf import DictConfig, OmegaConf

TRANSFORM_DIRNAME = (
    Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
)


def _load_config(filepath: str) -> DictConfig:
    """Read a transform config file and parse it with OmegaConf.

    Args:
        filepath: Path of the config file; it is joined onto
            ``TRANSFORM_DIRNAME``.

    Returns:
        The configuration object produced by ``OmegaConf.load``.
    """
    log.debug(f"Loading transforms from config: {filepath}")
    path = TRANSFORM_DIRNAME / filepath
    # Decode explicitly as UTF-8 so that parsing does not depend on the
    # platform's default locale encoding.
    with open(path, encoding="utf-8") as f:
        return OmegaConf.load(f)


def _load_transform(transform: DictConfig) -> Callable:
    """Build a single transform from its config node.

    A target whose name contains ``"ColorJitter"`` is constructed directly as
    ``T.ColorJitter`` with its arguments converted to plain Python values.
    Any other target is instantiated through Hydra; an ``interpolation``
    given by name is first translated to the matching
    ``T.functional.InterpolationMode`` member.

    Args:
        transform: Config node holding ``_target_`` and the target's keyword
            arguments. The node is left unmodified, so the same config can be
            loaded more than once.

    Returns:
        The instantiated transform.

    Raises:
        AttributeError: If ``interpolation`` is a string that does not name
            an ``InterpolationMode`` member.
    """
    if "ColorJitter" in transform._target_:
        # Convert OmegaConf containers to plain lists/scalars and forward
        # every configured jitter argument, not only ``brightness``.
        plain = OmegaConf.to_container(transform, resolve=True)
        jitter_kwargs = {
            name: plain[name]
            for name in ("brightness", "contrast", "saturation", "hue")
            if plain.get(name) is not None
        }
        return T.ColorJitter(**jitter_kwargs)

    # Supply the resolved enum as a keyword override instead of writing it
    # back into ``transform``: mutating the caller's config made a second
    # load fail, because the stored value was no longer a member name.
    overrides = {}
    interpolation = transform.get("interpolation")
    if interpolation and isinstance(interpolation, str):
        overrides["interpolation"] = getattr(
            T.functional.InterpolationMode, interpolation
        )
    return instantiate(transform, _recursive_=False, **overrides)


def load_transform_from_file(filepath: str) -> T.Compose:
    """Build a composed transform from a config file.

    Args:
        filepath: Location of the config file, forwarded to ``_load_config``.

    Returns:
        The ``T.Compose`` that ``load_transform`` builds from the loaded
        config.
    """
    return load_transform(_load_config(filepath))


def load_transform(cfgs: DictConfig) -> T.Compose:
    """Compose the transforms described by a config.

    Each value of ``cfgs`` is turned into a transform by ``_load_transform``;
    the results are chained in the order the config yields them.

    Args:
        cfgs: Config whose values are individual transform configs.

    Returns:
        A ``T.Compose`` wrapping the instantiated transforms.
    """
    return T.Compose([_load_transform(entry) for entry in cfgs.values()])