Diffstat (limited to 'training')
-rw-r--r--  training/conf/logger/wandb.yaml   15
-rw-r--r--  training/main.py                  21
-rw-r--r--  training/run.py                   95
-rw-r--r--  training/utils.py                 56
4 files changed, 187 insertions, 0 deletions
diff --git a/training/conf/logger/wandb.yaml b/training/conf/logger/wandb.yaml
new file mode 100644
index 0000000..552cf00
--- /dev/null
+++ b/training/conf/logger/wandb.yaml
@@ -0,0 +1,15 @@
+# https://wandb.ai
+
+wandb:
+  _target_: pytorch_lightning.loggers.wandb.WandbLogger
+  project: "text-recognizer"
+  name: null
+  save_dir: "."
+  offline: False # set True to store all logs only locally
+  id: null # pass correct id to resume experiment!
+  # entity: "" # set to name of your wandb team or just remove it
+  log_model: False
+  prefix: ""
+  job_type: "train"
+  group: ""
+  tags: []
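
As a brief illustration of how this logger group is consumed (a sketch, assuming the file is composed into the top-level Hydra config under the `logger` group so it appears as `config.logger.wandb`), `configure_logger` in training/run.py below effectively does the following:

    import hydra
    from omegaconf import OmegaConf

    # Load the group as it would look after Hydra composition (the path is an assumption).
    logger_config = OmegaConf.load("training/conf/logger/wandb.yaml").wandb
    # _target_ names the class to build; the remaining keys become keyword arguments.
    wandb_logger = hydra.utils.instantiate(logger_config)  # WandbLogger(project="text-recognizer", ...)
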
diff --git a/training/main.py b/training/main.py
new file mode 100644
index 0000000..26d5aeb
--- /dev/null
+++ b/training/main.py
@@ -0,0 +1,21 @@
+"""Loads config with hydra and runs experiment."""
+import hydra
+from omegaconf import DictConfig
+
+
+@hydra.main(config_path="conf", config_name="config")
+def main(config: DictConfig) -> None:
+    """Loads config with hydra and runs the experiment."""
+    import utils
+    from run import run
+
+    utils.extras(config)
+
+    if config.get("print_config"):
+        utils.print_config(config)
+
+    return run(config)
+
+
+if __name__ == "__main__":
+    main()
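
For reference, the same configuration can also be composed outside the `@hydra.main` entry point, which is convenient in notebooks and tests. A minimal sketch, assuming Hydra 1.1+ (where `initialize`/`compose` are top-level) and that training/conf/config.yaml exists as referenced above; the `debug` override corresponds to the key used in run.py and utils.py:

    from hydra import compose, initialize

    # config_path is relative to the calling file, mirroring @hydra.main above.
    with initialize(config_path="conf"):
        config = compose(config_name="config", overrides=["debug=true"])
    print(config.debug)
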
diff --git a/training/run.py b/training/run.py
new file mode 100644
index 0000000..ed1b372
--- /dev/null
+++ b/training/run.py
@@ -0,0 +1,95 @@
+"""Script to run experiments."""
+from typing import List, Optional
+
+import hydra
+from loguru import logger as log
+from omegaconf import DictConfig
+from pytorch_lightning import (
+    Callback,
+    LightningDataModule,
+    LightningModule,
+    Trainer,
+    seed_everything,
+)
+from pytorch_lightning.loggers import LightningLoggerBase
+from torch import nn
+
+from utils import configure_logging
+
+
+def configure_callbacks(
+    config: DictConfig,
+) -> List[Callback]:
+    """Configures lightning callbacks."""
+    callbacks = []
+    if config.get("callbacks"):
+        for callback_config in config.callbacks.values():
+            if callback_config.get("_target_"):
+                log.info(f"Instantiating callback <{callback_config._target_}>")
+                callbacks.append(hydra.utils.instantiate(callback_config))
+    return callbacks
+
+
+def configure_logger(config: DictConfig) -> List[LightningLoggerBase]:
+    logger = []
+    if config.get("logger"):
+        for logger_config in config.logger.values():
+            if logger_config.get("_target_"):
+                log.info(f"Instantiating logger <{logger_config._target_}>")
+                logger.append(hydra.utils.instantiate(logger_config))
+    return logger
+
+
+def run(config: DictConfig) -> Optional[float]:
+    """Runs experiment."""
+    configure_logging(config.logging)
+    log.info("Starting experiment...")
+
+    if config.get("seed"):
+        seed_everything(config.seed)
+
+    log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
+    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
+
+    log.info(f"Instantiating network <{config.network._target_}>")
+    network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config())
+
+    log.info(f"Instantiating model <{config.model._target_}>")
+    model: LightningModule = hydra.utils.instantiate(
+        config.model,
+        network=network,
+        criterion=config.criterion,
+        optimizer=config.optimizer,
+        lr_scheduler=config.lr_scheduler,
+        _recursive_=False,
+    )
+
+    # Load callback and logger.
+    callbacks = configure_callbacks(config)
+    logger = configure_logger(config)
+
+    log.info(f"Instantiating trainer <{config.trainer._target_}>")
+    trainer: Trainer = hydra.utils.instantiate(
+        config.trainer, callbacks=callbacks, logger=logger, _convert_="partial"
+    )
+
+    # Log hyperparameters
+
+    if config.debug:
+        log.info("Fast development run...")
+        trainer.fit(model, datamodule=datamodule)
+        return None
+
+    if config.tune:
+        log.info("Tuning learning rate and batch size...")
+        trainer.tune(model, datamodule=datamodule)
+
+    if config.train:
+        log.info("Training network...")
+        trainer.fit(model, datamodule=datamodule)
+
+    if config.test:
+        log.info("Testing network...")
+        trainer.test(model, datamodule=datamodule)
+
+    # Make sure everything closes properly
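
The trailing comment hints at a cleanup step that is not implemented in this hunk. One way to do it (a sketch, not part of this commit: the `finish` helper is an assumption) is to close any active wandb run so that Hydra multiruns do not keep logging to the same run:

    from typing import List

    import wandb
    from pytorch_lightning.loggers import LightningLoggerBase
    from pytorch_lightning.loggers.wandb import WandbLogger

    def finish(logger: List[LightningLoggerBase]) -> None:
        """Close any wandb run opened by a WandbLogger (hypothetical helper)."""
        if any(isinstance(lg, WandbLogger) for lg in logger):
            wandb.finish()
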
diff --git a/training/utils.py b/training/utils.py
new file mode 100644
index 0000000..7717fc5
--- /dev/null
+++ b/training/utils.py
@@ -0,0 +1,56 @@
+"""Util functions for training hydra configs and pytorch lightning."""
+import warnings
+
+from loguru import logger as log
+from omegaconf import DictConfig, OmegaConf
+from pytorch_lightning.loggers.wandb import WandbLogger
+from pytorch_lightning.utilities import rank_zero_only
+from tqdm import tqdm
+
+
+@rank_zero_only
+def configure_logging(level: str) -> None:
+    """Configure the loguru logger for output to the terminal."""
+    # Remove default logger to get tqdm to work properly.
+    log.remove()
+    log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
+
+
+def extras(config: DictConfig) -> None:
+    """Sets optional utilities."""
+    # Enable adding new keys.
+    OmegaConf.set_struct(config, False)
+
+    if config.get("ignore_warnings"):
+        log.info("Disabling python warnings! <config.ignore_warnings=True>")
+        warnings.filterwarnings("ignore")
+
+    if config.get("debug"):
+        log.info("Running in debug mode! <config.debug=True>")
+        config.trainer.fast_dev_run = True
+
+    if config.trainer.get("fast_dev_run"):
+        log.info(
+            "Forcing debugger friendly configuration! <config.trainer.fast_dev_run=True>"
+        )
+        # Debuggers do not like GPUs and multiprocessing.
+        if config.trainer.get("gpus"):
+            config.trainer.gpus = 0
+        if config.datamodule.get("pin_memory"):
+            config.datamodule.pin_memory = False
+        if config.datamodule.get("num_workers"):
+            config.datamodule.num_workers = 0
+
+    # Force multi-gpu friendly config.
+    accelerator = config.trainer.get("accelerator")
+    if accelerator in ["ddp", "ddp_spawn", "dp", "ddp2"]:
+        log.info(
+            f"Forcing ddp friendly configuration! <config.trainer.accelerator={accelerator}>"
+        )
+        if config.datamodule.get("pin_memory"):
+            config.datamodule.pin_memory = False
+        if config.datamodule.get("num_workers"):
+            config.datamodule.num_workers = 0
+
+    # Disable adding new keys to config.
+    OmegaConf.set_struct(config, True)
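
Note that main.py above calls utils.print_config, which is not part of this hunk. A minimal sketch of what such a helper could look like (a hypothetical implementation; the repository's actual version may differ, e.g. by using rich for pretty-printing):

    from loguru import logger as log
    from omegaconf import DictConfig, OmegaConf
    from pytorch_lightning.utilities import rank_zero_only

    @rank_zero_only
    def print_config(config: DictConfig) -> None:
        """Print the composed Hydra config as resolved YAML (minimal sketch)."""
        log.info(f"Config:\n{OmegaConf.to_yaml(config, resolve=True)}")
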