blob: e8c57bc24b86bd050c08566340500fc57a31dd46 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
|
"""Load a config of transforms."""
from pathlib import Path
from typing import Callable
import torchvision.transforms as T
from hydra.utils import instantiate
from loguru import logger as log
from omegaconf import DictConfig, OmegaConf
TRANSFORM_DIRNAME = (
Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
)
def _load_config(filepath: str) -> DictConfig:
    """Read a transforms YAML config located under ``TRANSFORM_DIRNAME``.

    Args:
        filepath: Path of the config file, relative to ``TRANSFORM_DIRNAME``.
            An absolute path replaces the base directory (``pathlib`` join
            semantics).

    Returns:
        The parsed config as returned by ``OmegaConf.load``.
        NOTE(review): ``OmegaConf.load`` yields a ``ListConfig`` when the YAML
        top level is a list, despite the ``DictConfig`` annotation.
    """
    log.debug(f"Loading transforms from config: {filepath}")
    path = TRANSFORM_DIRNAME / filepath
    # Decode explicitly as UTF-8 instead of the platform's locale default.
    with open(path, encoding="utf-8") as f:
        return OmegaConf.load(f)
def _load_transform(transform: DictConfig) -> Callable:
    """Instantiate a single transform from its config node.

    ``ColorJitter`` is constructed directly. Its jitter factors are converted
    from OmegaConf containers to plain Python lists, because torchvision's
    range check accepts ``tuple``/``list`` but not ``ListConfig``.

    Every other transform goes through Hydra's ``instantiate`` (non-recursive).
    A string ``interpolation`` value (e.g. ``"BILINEAR"``) is first resolved to
    the matching ``torchvision.transforms.functional.InterpolationMode`` member.

    The given config is not modified, so it can be instantiated repeatedly.

    Args:
        transform: Config node whose ``_target_`` names the transform class.

    Returns:
        The instantiated transform.
    """
    if "ColorJitter" in transform._target_:
        # Forward every jitter factor present in the config (not only
        # brightness), unwrapping OmegaConf sequences into plain lists.
        params = {
            name: (
                OmegaConf.to_container(value, resolve=True)
                if OmegaConf.is_config(value)
                else value
            )
            for name, value in transform.items()
            if name in ("brightness", "contrast", "saturation", "hue")
        }
        return T.ColorJitter(**params)
    interpolation = transform.get("interpolation")
    if interpolation:
        # Pass the enum as an instantiate() override rather than assigning it
        # into the config. Assigning would mutate the caller's config (a second
        # call would then run getattr() with an enum and raise TypeError) and
        # would fail on read-only configs.
        mode = getattr(T.functional.InterpolationMode, interpolation)
        return instantiate(transform, _recursive_=False, interpolation=mode)
    return instantiate(transform, _recursive_=False)
def load_transform_from_file(filepath: str) -> T.Compose:
    """Build a composed transform pipeline from a YAML config file.

    Args:
        filepath: Path of the config file, relative to ``TRANSFORM_DIRNAME``.

    Returns:
        A ``torchvision.transforms.Compose`` of the configured transforms.
    """
    return load_transform(_load_config(filepath))
def load_transform(cfgs: DictConfig) -> T.Compose:
    """Instantiate every transform in ``cfgs`` and chain them in order.

    Args:
        cfgs: Mapping whose values are transform config nodes. Only the values
            are used, taken in the config's iteration order; keys are ignored.

    Returns:
        A ``torchvision.transforms.Compose`` applying the transforms in that
        order. An empty config yields an empty ``Compose``.
    """
    return T.Compose([_load_transform(cfg) for cfg in cfgs.values()])
|