summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/load_transform.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/data/transforms/load_transform.py')
-rw-r--r--text_recognizer/data/transforms/load_transform.py46
1 files changed, 0 insertions, 46 deletions
diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py
deleted file mode 100644
index e8c57bc..0000000
--- a/text_recognizer/data/transforms/load_transform.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""Load a config of transforms."""
-from pathlib import Path
-from typing import Callable
-
-import torchvision.transforms as T
-from hydra.utils import instantiate
-from loguru import logger as log
-from omegaconf import DictConfig, OmegaConf
-
-TRANSFORM_DIRNAME = (
- Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
-)
-
-
-def _load_config(filepath: str) -> DictConfig:
- log.debug(f"Loading transforms from config: {filepath}")
- path = TRANSFORM_DIRNAME / Path(filepath)
- with open(path) as f:
- cfgs = OmegaConf.load(f)
- return cfgs
-
-
-def _load_transform(transform: DictConfig) -> Callable:
- """Loads a transform."""
- if "ColorJitter" in transform._target_:
- return T.ColorJitter(brightness=list(transform.brightness))
- if transform.get("interpolation"):
- transform.interpolation = getattr(
- T.functional.InterpolationMode, transform.interpolation
- )
- return instantiate(transform, _recursive_=False)
-
-
-def load_transform_from_file(filepath: str) -> T.Compose:
- """Loads transforms from a config."""
- cfgs = _load_config(filepath)
- transform = load_transform(cfgs)
- return transform
-
-
-def load_transform(cfgs: DictConfig) -> T.Compose:
- transforms = []
- for cfg in cfgs.values():
- transform = _load_transform(cfg)
- transforms.append(transform)
- return T.Compose(transforms)