Diffstat (limited to 'training/run_experiment.py')
-rw-r--r--  training/run_experiment.py | 66
1 file changed, 41 insertions(+), 25 deletions(-)
diff --git a/training/run_experiment.py b/training/run_experiment.py
index 1e79461..e1aae4e 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -1,4 +1,5 @@
"""Script to run experiments."""
+from datetime import datetime
import importlib
from pathlib import Path
from typing import Dict, List, Optional, Type
@@ -7,7 +8,6 @@ import click
from loguru import logger
from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
-import torch
from torch import nn
from tqdm import tqdm
import wandb
@@ -18,6 +18,17 @@ CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs"
LOGS_DIRNAME = Path(__file__).parent.resolve() / "logs"
+def _create_experiment_dir(config: DictConfig) -> Path:
+ """Creates log directory for experiment."""
+ log_dir = (
+ LOGS_DIRNAME
+ / f"{config.model.type}_{config.network.type}"
+ / datetime.now().strftime("%m%d_%H%M%S")
+ )
+ log_dir.mkdir(parents=True, exist_ok=True)
+ return log_dir
+
+
def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
"""Configure the loguru logger for output to terminal and disk."""
@@ -67,15 +78,15 @@ def _configure_callbacks(
def _configure_logger(
-    network: Type[nn.Module], args: Dict, log_dir: str, use_wandb: bool
+    network: Type[nn.Module], args: Dict, log_dir: Path, use_wandb: bool
) -> Type[pl.loggers.LightningLoggerBase]:
"""Configures lightning logger."""
if use_wandb:
- pl_logger = pl.loggers.WandbLogger(save_dir=log_dir)
+ pl_logger = pl.loggers.WandbLogger(save_dir=str(log_dir))
pl_logger.watch(network)
pl_logger.log_hyperparams(vars(args))
return pl_logger
- return pl.logger.TensorBoardLogger(save_dir=log_dir)
+ return pl.loggers.TensorBoardLogger(save_dir=str(log_dir))
def _save_best_weights(
@@ -111,6 +122,7 @@ def _load_lit_model(
def run(
    filename: str,
+    fast_dev_run: bool,
    train: bool,
    test: bool,
    tune: bool,
@@ -118,20 +130,18 @@ def run(
    verbose: int = 0,
) -> None:
    """Runs experiment."""
-    # Set log dir where logging output and weights are saved to.
-    log_dir = str(LOGS_DIRNAME)
+    # Load config.
+    file_path = CONFIGS_DIRNAME / filename
+    config = _load_config(file_path)
-    _configure_logging(None, verbose=verbose)
+    log_dir = _create_experiment_dir(config)
+    _configure_logging(log_dir, verbose=verbose)
    logger.info("Starting experiment...")
    # Seed everything in the experiment.
    logger.info(f"Seeding everthing with seed={SEED}")
    pl.utilities.seed.seed_everything(SEED)
-    # Load config.
-    file_path = CONFIGS_DIRNAME / filename
-    config = _load_config(file_path)
-
    # Load classes.
    data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
    network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
@@ -152,36 +162,41 @@ def run(
        **config.trainer.args,
        callbacks=callbacks,
        logger=pl_logger,
-        weigths_save_path=log_dir,
+        weights_save_path=str(log_dir),
    )
-
-    if tune:
-        logger.info(f"Tuning learning rate and batch size...")
-        trainer.tune(lit_model, datamodule=data_module)
-
-    if train:
-        logger.info(f"Training network...")
+    if fast_dev_run:
+        logger.info("Fast dev run...")
        trainer.fit(lit_model, datamodule=data_module)
+    else:
+        if tune:
+            logger.info("Tuning learning rate and batch size...")
+            trainer.tune(lit_model, datamodule=data_module)
+
+        if train:
+            logger.info("Training network...")
+            trainer.fit(lit_model, datamodule=data_module)
-    if test:
-        logger.info(f"Testing network...")
-        trainer.test(lit_model, datamodule=data_module)
+        if test:
+            logger.info("Testing network...")
+            trainer.test(lit_model, datamodule=data_module)
-    _save_best_weights(callbacks, use_wandb)
+        _save_best_weights(callbacks, use_wandb)
@click.command()
@click.option("-f", "--experiment_config", type=str, help="Path to experiment config.")
@click.option("--use_wandb", is_flag=True, help="If true, do use wandb for logging.")
+@click.option("--dev", is_flag=True, help="If true, run a fast dev run.")
@click.option(
"--tune", is_flag=True, help="If true, tune hyperparameters for training."
)
-@click.option("--train", is_flag=True, help="If true, train the model.")
-@click.option("--test", is_flag=True, help="If true, test the model.")
+@click.option("-t", "--train", is_flag=True, help="If true, train the model.")
+@click.option("-e", "--test", is_flag=True, help="If true, test the model.")
@click.option("-v", "--verbose", count=True)
def cli(
    experiment_config: str,
    use_wandb: bool,
+    dev: bool,
    tune: bool,
    train: bool,
    test: bool,
@@ -190,6 +205,7 @@ def cli(
"""Run experiment."""
run(
filename=experiment_config,
+ fast_dev_run=dev,
train=train,
test=test,
tune=tune,