summaryrefslogtreecommitdiff
path: root/text_recognizer/data/transforms/load_transform.py
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-10-10 18:05:24 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-10-10 18:05:24 +0200
commit 38f546f0b86fc0dc89863b00c5ee8c6685394ef2 (patch)
tree deaf5501c8e687e235add858ad92e068df7d3615 /text_recognizer/data/transforms/load_transform.py
parent 8291a87c64f9a5f18caec82201bea15579b49730 (diff)
Add custom transforms
Diffstat (limited to 'text_recognizer/data/transforms/load_transform.py')
-rw-r--r-- text_recognizer/data/transforms/load_transform.py 47
1 files changed, 47 insertions, 0 deletions
diff --git a/text_recognizer/data/transforms/load_transform.py b/text_recognizer/data/transforms/load_transform.py
new file mode 100644
index 0000000..cf590c1
--- /dev/null
+++ b/text_recognizer/data/transforms/load_transform.py
@@ -0,0 +1,47 @@
+"""Load a config of transforms."""
+from pathlib import Path
+from typing import Callable
+
+from loguru import logger as log
+from omegaconf import OmegaConf
+from omegaconf import DictConfig
+from hydra.utils import instantiate
+import torchvision.transforms as T
+
+# Directory that transform config paths are joined onto: the fourth ancestor
+# directory of this file, followed by training/conf/datamodule.
+TRANSFORM_DIRNAME = (
+ Path(__file__).resolve().parents[3] / "training" / "conf" / "datamodule"
+)
+
+
+def _load_config(filepath: str) -> DictConfig:
+ log.debug(f"Loading transforms from config: {filepath}")
+ path = TRANSFORM_DIRNAME / Path(filepath)
+ with open(path) as f:
+ cfgs = OmegaConf.load(f)
+ return cfgs
+
+
+def _load_transform(transform: DictConfig) -> Callable:
+ """Loads a transform."""
+ if "ColorJitter" in transform._target_:
+ return T.ColorJitter(brightness=list(transform.brightness))
+ if transform.get("interpolation"):
+ transform.interpolation = getattr(
+ T.functional.InterpolationMode, transform.interpolation
+ )
+ return instantiate(transform, _recursive_=False)
+
+
+def load_transform_from_file(filepath: str) -> T.Compose:
+ """Loads transforms from a config."""
+ cfgs = _load_config(filepath)
+ transform = load_transform(cfgs)
+ return transform
+
+
+def load_transform(cfgs: DictConfig) -> T.Compose:
+ transforms = []
+ for cfg in cfgs.values():
+ transform = _load_transform(cfg)
+ transforms.append(transform)
+ return T.Compose(transforms)