summaryrefslogtreecommitdiff
path: root/training/run_experiment.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 20:47:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 20:47:55 +0200
commit9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch)
tree4fe2bcd82553c8062eb0908ae6442c123addf55d /training/run_experiment.py
parent9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff)
Add new training loop with PyTorch Lightning, remove stale files
Diffstat (limited to 'training/run_experiment.py')
-rw-r--r--training/run_experiment.py419
1 files changed, 105 insertions, 314 deletions
diff --git a/training/run_experiment.py b/training/run_experiment.py
index faafea6..ff8b886 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -1,162 +1,34 @@
"""Script to run experiments."""
from datetime import datetime
-from glob import glob
import importlib
-import json
-import os
from pathlib import Path
-import re
-from typing import Callable, Dict, List, Optional, Tuple, Type
-import warnings
+from typing import Dict, List, Optional, Type
import click
from loguru import logger
import numpy as np
+from omegaconf import OmegaConf
+import pytorch_lightning as pl
import torch
+from torch import nn
from torchsummary import summary
from tqdm import tqdm
-from training.gpu_manager import GPUManager
-from training.trainer.callbacks import CallbackList
-from training.trainer.train import Trainer
import wandb
-import yaml
-import text_recognizer.models
-from text_recognizer.models import Model
-import text_recognizer.networks
-from text_recognizer.networks.loss import loss as custom_loss_module
+SEED = 4711
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-def _get_level(verbose: int) -> int:
- """Sets the logger level."""
- levels = {0: 40, 1: 20, 2: 10}
- verbose = verbose if verbose <= 2 else 2
- return levels[verbose]
-
-
-def _create_experiment_dir(
- experiment_config: Dict, checkpoint: Optional[str] = None
-) -> Path:
- """Create new experiment."""
- EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
- experiment_dir = EXPERIMENTS_DIRNAME / (
- f"{experiment_config['model']}_"
- + f"{experiment_config['dataset']['type']}_"
- + f"{experiment_config['network']['type']}"
- )
-
- if checkpoint is None:
- experiment = datetime.now().strftime("%m%d_%H%M%S")
- logger.debug(f"Creating a new experiment called {experiment}")
- else:
- available_experiments = glob(str(experiment_dir) + "/*")
- available_experiments.sort()
- if checkpoint == "last":
- experiment = available_experiments[-1]
- logger.debug(f"Resuming the latest experiment {experiment}")
- else:
- experiment = checkpoint
- if not str(experiment_dir / experiment) in available_experiments:
- raise FileNotFoundError("Experiment does not exist.")
- logger.debug(f"Resuming the from experiment {checkpoint}")
-
- experiment_dir = experiment_dir / experiment
-
- # Create log and model directories.
- log_dir = experiment_dir / "log"
- model_dir = experiment_dir / "model"
-
- return experiment_dir, log_dir, model_dir
-
-
-def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]:
- """Loads all modules and arguments."""
- # Load the dataset module.
- dataset_args = experiment_config.get("dataset", {})
- dataset_ = dataset_args["type"]
-
- # Import the model module and model arguments.
- model_class_ = getattr(text_recognizer.models, experiment_config["model"])
-
- # Import metrics.
- metric_fns_ = (
- {
- metric: getattr(text_recognizer.networks, metric)
- for metric in experiment_config["metrics"]
- }
- if experiment_config["metrics"] is not None
- else None
- )
-
- # Import network module and arguments.
- network_fn_ = experiment_config["network"]["type"]
- network_args = experiment_config["network"].get("args", {})
-
- # Criterion
- if experiment_config["criterion"]["type"] in custom_loss_module.__all__:
- criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"])
- else:
- criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {}) or {}
-
- # Optimizers
- optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
- optimizer_args = experiment_config["optimizer"].get("args", {})
-
- # Learning rate scheduler
- lr_scheduler_ = None
- lr_scheduler_args = None
- if "lr_scheduler" in experiment_config:
- lr_scheduler_ = getattr(
- torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
- )
- lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {}
-
- # SWA scheduler.
- if "swa_args" in experiment_config:
- swa_args = experiment_config.get("swa_args", {}) or {}
- else:
- swa_args = None
-
- model_args = {
- "dataset": dataset_,
- "dataset_args": dataset_args,
- "metrics": metric_fns_,
- "network_fn": network_fn_,
- "network_args": network_args,
- "criterion": criterion_,
- "criterion_args": criterion_args,
- "optimizer": optimizer_,
- "optimizer_args": optimizer_args,
- "lr_scheduler": lr_scheduler_,
- "lr_scheduler_args": lr_scheduler_args,
- "swa_args": swa_args,
- }
-
- return model_class_, model_args
-
-
-def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList:
- """Configure a callback list for trainer."""
- if "Checkpoint" in experiment_config["callback_args"]:
- experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str(
- model_dir
- )
-
- # Initializes callbacks.
- callback_modules = importlib.import_module("training.trainer.callbacks")
- callbacks = []
- for callback in experiment_config["callbacks"]:
- args = experiment_config["callback_args"][callback] or {}
- callbacks.append(getattr(callback_modules, callback)(**args))
-
- return callbacks
def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
    """Configure the loguru logger for output to terminal and disk.

    Args:
        log_dir: Directory for the ``train.log`` file; ``None`` disables file logging.
        verbose: Verbosity count (0=ERROR, 1=INFO, >=2=DEBUG).
    """

    def _get_level(verbose: int) -> int:
        """Map the verbosity count to a loguru level number."""
        levels = {0: 40, 1: 20, 2: 10}
        return levels[min(verbose, 2)]

    # The default sink has to be removed first, otherwise log output
    # interleaves badly with tqdm progress bars.
    logger.remove()

    level = _get_level(verbose)
    # Route terminal output through tqdm.write so progress bars stay intact.
    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)

    if log_dir is not None:
        logger.add(
            str(log_dir / "train.log"),
            format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
        )
- else:
- logger.info(f"Loading weights from {model_dir}.")
- checkpoint_path = model_dir / "last.pt"
- if checkpoint_path.exists():
- logger.info("Loading and resuming training from checkpoint.")
- model.load_from_checkpoint(checkpoint_path)
-def evaluate_embedding(model: Type[Model]) -> Dict:
- """Evaluates the embedding space."""
- from pytorch_metric_learning import testers
- from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
+def _import_class(module_and_class_name: str) -> type:
+ """Import class from module."""
+ module_name, class_name = module_and_class_name.rsplit(".", 1)
+ module = importlib.import_module(module_name)
+ return getattr(module, class_name)
- accuracy_calculator = AccuracyCalculator(
- include=("mean_average_precision_at_r",), k=10
- )
- def get_all_embeddings(model: Type[Model]) -> Tuple:
- tester = testers.BaseTester()
- return tester.get_all_embeddings(model.test_dataset, model.network)
def _configure_pl_callbacks(args: List[Dict]) -> List[Type[pl.callbacks.Callback]]:
    """Instantiate PyTorch Lightning callbacks from a list of config entries.

    Each entry must provide a ``type`` (class name under ``pl.callbacks``)
    and an ``args`` mapping of constructor keyword arguments.
    """
    pl_callbacks = []
    for callback_conf in args:
        callback_class = getattr(pl.callbacks, callback_conf["type"])
        pl_callbacks.append(callback_class(**callback_conf["args"]))
    return pl_callbacks
- embeddings, labels = get_all_embeddings(model)
- logger.info("Computing embedding accuracy")
- accuracies = accuracy_calculator.get_accuracy(
- embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True
- )
- logger.info(
- f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}"
- )
- return accuracies
def _configure_wandb_callback(
    network: Type[nn.Module], args: Dict
) -> pl.loggers.WandbLogger:
    """Configure a W&B logger that watches the network and logs hyperparameters.

    Args:
        network: Network whose gradients/parameters W&B should watch.
        args: Mapping of network hyperparameters to record
            (caller passes ``config.network.args``).

    Returns:
        The configured ``WandbLogger``.
    """
    pl_logger = pl.loggers.WandbLogger()
    pl_logger.watch(network)
    # Bug fix: ``vars(args)`` raises TypeError for plain dicts and OmegaConf
    # DictConfigs (neither exposes a usable ``__dict__``); pass the mapping
    # itself instead.
    pl_logger.log_hyperparams(dict(args))
    return pl_logger
-def run_experiment(
- experiment_config: Dict,
- save_weights: bool,
- device: str,
- use_wandb: bool,
- train: bool,
- test: bool,
- verbose: int = 0,
- checkpoint: Optional[str] = None,
- pretrained_weights: Optional[str] = None,
-) -> None:
- """Runs an experiment."""
- logger.info(f"Experiment config: {json.dumps(experiment_config)}")
- # Create new experiment.
- experiment_dir, log_dir, model_dir = _create_experiment_dir(
- experiment_config, checkpoint
def _save_best_weights(
    callbacks: List[Type[pl.callbacks.Callback]], use_wandb: bool
) -> None:
    """Log the path of the best checkpoint and optionally upload it to W&B.

    Args:
        callbacks: Configured callbacks; searched for a ``ModelCheckpoint``.
        use_wandb: Upload the best checkpoint file to Weights & Biases.
    """
    # Bug fix: ``next`` without a default raises StopIteration when the
    # experiment config declares no ModelCheckpoint callback.
    model_checkpoint_callback = next(
        (
            callback
            for callback in callbacks
            if isinstance(callback, pl.callbacks.ModelCheckpoint)
        ),
        None,
    )
    if model_checkpoint_callback is None:
        logger.warning("No ModelCheckpoint callback configured; nothing to save.")
        return
    best_model_path = model_checkpoint_callback.best_model_path
    if best_model_path:
        logger.info(f"Best model saved at: {best_model_path}")
        if use_wandb:
            logger.info("Uploading model to W&B...")
            wandb.save(best_model_path)
- # Make sure the log/model directory exists.
- log_dir.mkdir(parents=True, exist_ok=True)
- model_dir.mkdir(parents=True, exist_ok=True)
-
- # Load the modules and model arguments.
- model_class_, model_args = _load_modules_and_arguments(experiment_config)
-
- # Initializes the model with experiment config.
- model = model_class_(**model_args, device=device)
-
- callbacks = _configure_callbacks(experiment_config, model_dir)
- # Setup logger.
- _configure_logger(log_dir, verbose)
def run(path: str, train: bool, test: bool, tune: bool, use_wandb: bool) -> None:
    """Runs an experiment described by an OmegaConf config file.

    Args:
        path: Path to the experiment config file.
        train: If True, train the model.
        test: If True, evaluate on the test set.
        tune: If True, tune learning rate and batch size first.
        use_wandb: If True, log to Weights & Biases; otherwise TensorBoard.
    """
    logger.info("Starting experiment...")

    # Seed everything for reproducibility.
    logger.info(f"Seeding everything with seed={SEED}")
    pl.utilities.seed.seed_everything(SEED)

    # Load config.
    logger.info(f"Loading config from: {path}")
    config = OmegaConf.load(path)

    # Resolve the data module, network, and lightning model classes.
    data_module_class = _import_class(f"text_recognizer.data.{config.data.type}")
    network_class = _import_class(f"text_recognizer.networks.{config.network.type}")
    lit_model_class = _import_class(f"text_recognizer.models.{config.model.type}")

    # Initialize data object and network.
    data_module = data_module_class(**config.data.args)
    network = network_class(**config.network.args)

    # Load callbacks and logger.
    callbacks = _configure_pl_callbacks(config.callbacks)
    pl_logger = (
        _configure_wandb_callback(network, config.network.args)
        if use_wandb
        # Bug fix: the module is ``pl.loggers`` (plural); ``pl.logger`` raises
        # AttributeError as soon as the non-wandb branch is taken.
        else pl.loggers.TensorBoardLogger("training/logs")
    )

    # Optionally resume network weights from a checkpoint.
    # ``config.get`` keeps configs without a ``load_checkpoint`` key working.
    if config.get("load_checkpoint") is not None:
        logger.info(f"Loading network weights from checkpoint: {config.load_checkpoint}")
        lit_model = lit_model_class.load_from_checkpoint(
            config.load_checkpoint, network=network, **config.model.args
        )
    else:
        # NOTE(review): unlike the checkpoint branch, ``network`` is not passed
        # here — confirm the model class builds its own network in this path.
        lit_model = lit_model_class(**config.model.args)

    trainer = pl.Trainer(
        **config.trainer,
        callbacks=callbacks,
        logger=pl_logger,
        # Bug fix: keyword was misspelled ``weigths_save_path``, which
        # pl.Trainer rejects as an unexpected argument.
        weights_save_path="training/logs",
    )

    if tune:
        logger.info("Tuning learning rate and batch size...")
        trainer.tune(lit_model, datamodule=data_module)

    if train:
        logger.info("Training network...")
        trainer.fit(lit_model, datamodule=data_module)

    if test:
        logger.info("Testing network...")
        trainer.test(lit_model, datamodule=data_module)

    _save_best_weights(callbacks, use_wandb)
@click.command()
@click.option(
    "-f",
    "--experiment_config",
    type=str,
    # Bug fix: without ``required=True`` a missing flag passes ``None`` through
    # to ``OmegaConf.load``, which fails with an obscure error.
    required=True,
    help="Path to experiment config.",
)
@click.option("--use_wandb", is_flag=True, help="If true, use wandb for logging.")
@click.option(
    "--tune", is_flag=True, help="If true, tune hyperparameters for training."
)
@click.option("--train", is_flag=True, help="If true, train the model.")
@click.option("--test", is_flag=True, help="If true, test the model.")
@click.option("-v", "--verbose", count=True)
def cli(
    experiment_config: str,
    use_wandb: bool,
    tune: bool,
    train: bool,
    test: bool,
    verbose: int,
) -> None:
    """Run experiment."""
    # Terminal-only logging here; ``run`` has no log directory at this point.
    _configure_logging(None, verbose=verbose)
    run(path=experiment_config, train=train, test=test, tune=tune, use_wandb=use_wandb)
# Script entry point: delegate to the click command.
if __name__ == "__main__":
    cli()