diff options
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r-- | training/run_experiment.py | 66 |
1 files changed, 41 insertions, 25 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py index 1e79461..e1aae4e 100644 --- a/training/run_experiment.py +++ b/training/run_experiment.py @@ -1,4 +1,5 @@ """Script to run experiments.""" +from datetime import datetime import importlib from pathlib import Path from typing import Dict, List, Optional, Type @@ -7,7 +8,6 @@ import click from loguru import logger from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl -import torch from torch import nn from tqdm import tqdm import wandb @@ -18,6 +18,17 @@ CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs" LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs" +def _create_experiment_dir(config: DictConfig) -> Path: + """Creates log directory for experiment.""" + log_dir = ( + LOGS_DIRNAME + / f"{config.model.type}_{config.network.type}" + / datetime.now().strftime("%m%d_%H%M%S") + ) + log_dir.mkdir(parents=True, exist_ok=True) + return log_dir + + def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: """Configure the loguru logger for output to terminal and disk.""" @@ -67,15 +78,15 @@ def _configure_callbacks( def _configure_logger( - network: Type[nn.Module], args: Dict, log_dir: str, use_wandb: bool + network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool ) -> Type[pl.loggers.LightningLoggerBase]: """Configures lightning logger.""" if use_wandb: - pl_logger = pl.loggers.WandbLogger(save_dir=log_dir) + pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir)) pl_logger.watch(network) pl_logger.log_hyperparams(vars(args)) return pl_logger - return pl.logger.TensorBoardLogger(save_dir=log_dir) + return pl.loggers.TensorBoardLogger(save_dir=str(log_dir)) def _save_best_weights( @@ -111,6 +122,7 @@ def _load_lit_model( def run( filename: str, + fast_dev_run: bool, train: bool, test: bool, tune: bool, @@ -118,20 +130,18 @@ def run( verbose: int = 0, ) -> None: """Runs experiment.""" - # Set log dir where logging output and weights are saved to. - log_dir = str(LOGS_DIRNAME) + # Load config. + file_path = CONFIGS_DIRNAME / filename + config = _load_config(file_path) - _configure_logging(None, verbose=verbose) + log_dir = _create_experiment_dir(config) + _configure_logging(log_dir, verbose=verbose) logger.info("Starting experiment...") # Seed everything in the experiment. logger.info(f"Seeding everthing with seed={SEED}") pl.utilities.seed.seed_everything(SEED) - # Load config. - file_path = CONFIGS_DIRNAME / filename - config = _load_config(file_path) - # Load classes. data_module_class = _import_class(f"text_recognizer.data.{config.data.type}") network_class = _import_class(f"text_recognizer.networks.{config.network.type}") @@ -152,36 +162,41 @@ def run( **config.trainer.args, callbacks=callbacks, logger=pl_logger, - weigths_save_path=log_dir, + weights_save_path=str(log_dir), ) - - if tune: - logger.info(f"Tuning learning rate and batch size...") - trainer.tune(lit_model, datamodule=data_module) - - if train: - logger.info(f"Training network...") + if fast_dev_run: + logger.info("Fast dev run...") trainer.fit(lit_model, datamodule=data_module) + else: + if tune: + logger.info("Tuning learning rate and batch size...") + trainer.tune(lit_model, datamodule=data_module) + + if train: + logger.info("Training network...") + trainer.fit(lit_model, datamodule=data_module) - if test: - logger.info(f"Testing network...") - trainer.test(lit_model, datamodule=data_module) + if test: + logger.info("Testing network...") + trainer.test(lit_model, datamodule=data_module) - _save_best_weights(callbacks, use_wandb) + _save_best_weights(callbacks, use_wandb) @click.command() @click.option("-f", "--experiment_config", type=str, help="Path to experiment config.") @click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.") +@click.option("--dev", is_flag=True, help="If true, run a fast dev run.") @click.option( "--tune", is_flag=True, help="If true, tune hyperparameters for training." ) -@click.option("--train", is_flag=True, help="If true, train the model.") -@click.option("--test", is_flag=True, help="If true, test the model.") +@click.option("-t", "--train", is_flag=True, help="If true, train the model.") +@click.option("-e", "--test", is_flag=True, help="If true, test the model.") @click.option("-v", "--verbose", count=True) def cli( experiment_config: str, use_wandb: bool, + dev: bool, tune: bool, train: bool, test: bool, @@ -190,6 +205,7 @@ def cli( """Run experiment.""" run( filename=experiment_config, + fast_dev_run=dev, train=train, test=test, tune=tune, |