diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-06 22:31:54 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-06 22:31:54 +0200 |
commit | 31d58f2108165802d26eb1c1bdb9e5f052b4dd26 (patch) | |
tree | 6f5c2dcb0eef814c71a34df98444be7e8f1d0b43 /training/run_experiment.py | |
parent | 5e11924ca6aaea7898caca94675f41f67706a406 (diff) |
Fix network args
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r-- | training/run_experiment.py | 36 |
1 files changed, 17 insertions, 19 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py index 0a67bfa..ea9f512 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -6,7 +6,6 @@ from typing import Dict, List, NamedTuple, Optional, Union, Type import click from loguru import logger -import numpy as np from omegaconf import OmegaConf import pytorch_lightning as pl import torch @@ -23,10 +22,10 @@ EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: """Configure the loguru logger for output to terminal and disk.""" - def _get_level(verbose: int) -> int: + def _get_level(verbose: int) -> str: """Sets the logger level.""" - levels = {0: 40, 1: 20, 2: 10} - verbose = verbose if verbose <= 2 else 2 + levels = {0: "WARNING", 1: "INFO", 2: "DEBUG"} + verbose = min(verbose, 2) return levels[verbose] # Have to remove default logger to get tqdm to work properly. @@ -50,7 +49,7 @@ def _import_class(module_and_class_name: str) -> type: return getattr(module, class_name) -def _configure_pl_callbacks( +def _configure_callbacks( args: List[Union[OmegaConf, NamedTuple]] ) -> List[Type[pl.callbacks.Callback]]: """Configures PyTorch Lightning callbacks.""" @@ -60,13 +59,16 @@ def _configure_pl_callbacks( return pl_callbacks -def _configure_wandb_callback( - network: Type[nn.Module], args: Dict +def _configure_logger( + network: Type[nn.Module], args: Dict, use_wandb: bool ) -> pl.loggers.WandbLogger: - """Configures wandb logger.""" - pl_logger = pl.loggers.WandbLogger() - pl_logger.watch(network) - pl_logger.log_hyperparams(vars(args)) + """Configures PyTorch Lightning logger.""" + if use_wandb: + pl_logger = pl.loggers.WandbLogger() + pl_logger.watch(network) + pl_logger.log_hyperparams(vars(args)) + else: + pl_logger = pl.logger.TensorBoardLogger("training/logs") return pl_logger @@ -106,15 +108,11 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None # Initialize data object and network. data_module = data_module_class(**config.data.args) - network = network_class(**config.network.args) + network = network_class(**data_module.config(), **config.network.args) # Load callback and logger - callbacks = _configure_pl_callbacks(config.callbacks) - pl_logger = ( - _configure_wandb_callback(network, config.network.args) - if use_wandb - else pl.logger.TensorBoardLogger("training/logs") - ) + callbacks = _configure_callbacks(config.callbacks) + pl_logger = _configure_logger(network, config, use_wandb) # Checkpoint if config.load_checkpoint is not None: @@ -125,7 +123,7 @@ def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None config.load_checkpoint, network=network, **config.model.args ) else: - lit_model = lit_model_class(**config.model.args) + lit_model = lit_model_class(network=network, **config.model.args) trainer = pl.Trainer( **config.trainer, |