summaryrefslogtreecommitdiff
path: root/training/run_experiment.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-06 22:31:54 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-06 22:31:54 +0200
commit31d58f2108165802d26eb1c1bdb9e5f052b4dd26 (patch)
tree6f5c2dcb0eef814c71a34df98444be7e8f1d0b43 /training/run_experiment.py
parent5e11924ca6aaea7898caca94675f41f67706a406 (diff)
Fix network args
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r--training/run_experiment.py36
1 files changed, 17 insertions, 19 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py
index 0a67bfa..ea9f512 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -6,7 +6,6 @@ from typing import Dict, List, NamedTuple, Optional, Union, Type
import click
from loguru import logger
-import numpy as np
from omegaconf import OmegaConf
import pytorch_lightning as pl
import torch
@@ -23,10 +22,10 @@ EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
"""Configure the loguru logger for output to terminal and disk."""
- def _get_level(verbose: int) -> int:
+ def _get_level(verbose: int) -> str:
"""Sets the logger level."""
- levels = {0: 40, 1: 20, 2: 10}
- verbose = verbose if verbose <= 2 else 2
+ levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"}
+ verbose = min(verbose, 2)
return levels[verbose]
# Have to remove default logger to get tqdm to work properly.
@@ -50,7 +49,7 @@ def _import_class(module_and_class_name: str) -> type:
return getattr(module, class_name)
-def _configure_pl_callbacks(
+def _configure_callbacks(
args: List[Union[OmegaConf, NamedTuple]]
) -> List[Type[pl.callbacks.Callback]]:
"""Configures PyTorch Lightning callbacks."""
@@ -60,13 +59,16 @@ def _configure_pl_callbacks(
return pl_callbacks
-def _configure_wandb_callback(
- network: Type[nn.Module], args: Dict
+def _configure_logger(
+ network: Type[nn.Module], args: Dict, use_wandb: bool
) -> pl.loggers.WandbLogger:
- """Configures wandb logger."""
- pl_logger = pl.loggers.WandbLogger()
- pl_logger.watch(network)
- pl_logger.log_hyperparams(vars(args))
+ """Configures PyTorch Lightning logger."""
+ if use_wandb:
+ pl_logger = pl.loggers.WandbLogger()
+ pl_logger.watch(network)
+ pl_logger.log_hyperparams(vars(args))
+ else:
+ pl_logger = pl.logger.TensorBoardLogger("training/logs")
return pl_logger
@@ -106,15 +108,11 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None
# Initialize data object and network.
data_module = data_module_class(**config.data.args)
- network = network_class(**config.network.args)
+ network = network_class(**data_module.config(), **config.network.args)
# Load callback and logger
- callbacks = _configure_pl_callbacks(config.callbacks)
- pl_logger = (
- _configure_wandb_callback(network, config.network.args)
- if use_wandb
- else pl.logger.TensorBoardLogger("training/logs")
- )
+ callbacks = _configure_callbacks(config.callbacks)
+ pl_logger = _configure_logger(network, config, use_wandb)
# Checkpoint
if config.load_checkpoint is not None:
@@ -125,7 +123,7 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None
config.load_checkpoint, network=network, **config.model.args
)
else:
- lit_model = lit_model_class(**config.model.args)
+ lit_model = lit_model_class(network=network, **config.model.args)
trainer = pl.Trainer(
**config.trainer,