diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:05:24 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-10 18:05:24 +0200 |
commit | 38f546f0b86fc0dc89863b00c5ee8c6685394ef2 (patch) | |
tree | deaf5501c8e687e235add858ad92e068df7d3615 /text_recognizer/data/transforms/load_transform.py | |
parent | 8291a87c64f9a5f18caec82201bea15579b49730 (diff) |
Add custom transforms
Diffstat (limited to 'text_recognizer/data/transforms/load_transform.py')
-rw-r--r-- | text_recognizer/data/transforms/load_transform.py | 47 |
1 files changed, 47 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py new file mode 100644 index 0000000..cf590c1 --- /dev/null +++ b/text_recognizer/data/transforms/load_transform.py @@ -0,0 +1,47 @@ +"""Load a config of transforms.""" +from pathlib import Path +from typing import Callable + +from loguru import logger as log +from omegaconf import OmegaConf +from omegaconf import DictConfig +from hydra.utils import instantiate +import torchvision.transforms as T + +TRANSFORM_DIRNAME = ( + Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule" +) + + +def _load_config(filepath: str) -> DictConfig: + log.debug(f"Loading transforms from config: {filepath}") + path = TRANSFORM_DIRNAME / Path(filepath) + with open(path) as f: + cfgs = OmegaConf.load(f) + return cfgs + + +def _load_transform(transform: DictConfig) -> Callable: + """Loads a transform.""" + if "ColorJitter" in transform._target_: + return T.ColorJitter(brightness=list(transform.brightness)) + if transform.get("interpolation"): + transform.interpolation = getattr( + T.functional.InterpolationMode, transform.interpolation + ) + return instantiate(transform, _recursive_=False) + + +def load_transform_from_file(filepath: str) -> T.Compose: + """Loads transforms from a config.""" + cfgs = _load_config(filepath) + transform = load_transform(cfgs) + return transform + + +def load_transform(cfgs: DictConfig) -> T.Compose: + transforms = [] + for cfg in cfgs.values(): + transform = _load_transform(cfg) + transforms.append(transform) + return T.Compose(transforms) |