blob: cf590c1bbccee962a817069328eaab7dc39e9b55 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
|
"""Load a config of transforms."""
from pathlib import Path
from typing import Callable
from loguru import logger as log
from omegaconf import OmegaConf
from omegaconf import DictConfig
from hydra.utils import instantiate
import torchvision.transforms as T
TRANSFORM_DIRNAME = (
Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
)
def _load_config(filepath: str) -> DictConfig:
log.debug(f"Loading transforms from config: {filepath}")
path = TRANSFORM_DIRNAME / Path(filepath)
with open(path) as f:
cfgs = OmegaConf.load(f)
return cfgs
def _load_transform(transform: DictConfig) -> Callable:
"""Loads a transform."""
if "ColorJitter" in transform._target_:
return T.ColorJitter(brightness=list(transform.brightness))
if transform.get("interpolation"):
transform.interpolation = getattr(
T.functional.InterpolationMode, transform.interpolation
)
return instantiate(transform, _recursive_=False)
def load_transform_from_file(filepath: str) -> T.Compose:
"""Loads transforms from a config."""
cfgs = _load_config(filepath)
transform = load_transform(cfgs)
return transform
def load_transform(cfgs: DictConfig) -> T.Compose:
transforms = []
for cfg in cfgs.values():
transform = _load_transform(cfg)
transforms.append(transform)
return T.Compose(transforms)
|