summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/load_transform.py
blob: cf590c1bbccee962a817069328eaab7dc39e9b55 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
"""Load a config of transforms."""
from pathlib import Path
from typing import Callable

from loguru import logger as log
from omegaconf import OmegaConf
from omegaconf import DictConfig
from hydra.utils import instantiate
import torchvision.transforms as T

TRANSFORM_DIRNAME = (
    Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
)


def _load_config(filepath: str) -> DictConfig:
    """Load a transform config file from the datamodule config directory.

    Args:
        filepath: Path of the config file, joined onto ``TRANSFORM_DIRNAME``.
            Following ``pathlib`` join semantics, an absolute path replaces
            the base directory entirely.

    Returns:
        The config object produced by ``OmegaConf.load``.

    Raises:
        OSError: If the file cannot be opened (e.g. ``FileNotFoundError``).
    """
    log.debug(f"Loading transforms from config: {filepath}")
    path = TRANSFORM_DIRNAME / filepath
    # Open with an explicit encoding: the platform default is locale
    # dependent and may fail to decode UTF-8 config files.
    with open(path, encoding="utf-8") as f:
        return OmegaConf.load(f)


def _load_transform(transform: DictConfig) -> Callable:
    """Build a single transform from its config node.

    Args:
        transform: Config node with a ``_target_`` entry and the keyword
            arguments of the transform. The node is not modified.

    Returns:
        The constructed transform.

    Raises:
        AttributeError: If ``interpolation`` is a string that does not name a
            member of ``InterpolationMode``.
    """
    if "ColorJitter" in transform._target_:
        # ColorJitter is constructed directly (not via ``instantiate``) with
        # plain Python containers -- presumably because it expects real
        # lists/tuples rather than OmegaConf ``ListConfig`` values.
        # Every configured parameter is forwarded, not only ``brightness``,
        # so contrast/saturation/hue settings are no longer silently dropped.
        params = OmegaConf.to_container(transform, resolve=True)
        kwargs = {
            key: value
            for key, value in params.items()
            # Skip ``_target_`` and any other underscore-prefixed control key.
            if not str(key).startswith("_")
        }
        return T.ColorJitter(**kwargs)

    overrides = {}
    interpolation = transform.get("interpolation")
    if interpolation and isinstance(interpolation, str):
        # Resolve the mode name to the enum member and hand it over as a
        # keyword override instead of writing it back into ``transform``:
        # mutating the caller's config made a second call on the same node
        # fail, because the value was then already an enum, not a name.
        overrides["interpolation"] = getattr(
            T.functional.InterpolationMode, interpolation
        )
    return instantiate(transform, _recursive_=False, **overrides)


def load_transform_from_file(filepath: str) -> T.Compose:
    """Read a transform config file and build the composed transform.

    Args:
        filepath: Config file path, passed to ``_load_config``.

    Returns:
        The ``T.Compose`` built by ``load_transform`` from the loaded config.
    """
    return load_transform(_load_config(filepath))


def load_transform(cfgs: DictConfig) -> T.Compose:
    """Compose the transforms described by the entries of a config.

    Args:
        cfgs: Mapping whose values are individual transform configs; each
            value is built with ``_load_transform`` in iteration order.

    Returns:
        A ``T.Compose`` wrapping the built transforms.
    """
    pipeline = [_load_transform(node) for node in cfgs.values()]
    return T.Compose(pipeline)