path: root/src/training/run_experiment.py
author    aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 14:54:44 +0100
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 14:54:44 +0100
commit    dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (patch)
tree      1b5fc0d06952e13727e85c4f973a26d277068453 /src/training/run_experiment.py
parent    e181195a699d7fa237f256d90ab4dedffc03d405 (diff)
new updates
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--  src/training/run_experiment.py  161
1 file changed, 111 insertions(+), 50 deletions(-)
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index a347d9f..0510d5c 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,12 +6,15 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, List, Tuple, Type
+from typing import Callable, Dict, List, Optional, Tuple, Type
+import warnings
+import adabelief_pytorch
import click
from loguru import logger
import numpy as np
import torch
+from torchsummary import summary
from tqdm import tqdm
from training.gpu_manager import GPUManager
from training.trainer.callbacks import Callback, CallbackList
@@ -21,26 +24,23 @@ import yaml
from text_recognizer.models import Model
-from text_recognizer.networks import losses
-
+from text_recognizer.networks import loss as custom_loss_module
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-CUSTOM_LOSSES = ["EmbeddingLoss"]
DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}
-def get_level(experiment_config: Dict) -> int:
+def _get_level(verbose: int) -> int:
"""Sets the logger level."""
- if experiment_config["verbosity"] == 0:
- return 40
- elif experiment_config["verbosity"] == 1:
- return 20
- else:
- return 10
+ levels = {0: 40, 1: 20, 2: 10}
+ verbose = verbose if verbose <= 2 else 2
+ return levels[verbose]
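For reference, the numeric values are the standard logging severities that loguru shares with the stdlib logging module, so the count of -v flags maps directly to a threshold. A minimal sketch (not part of the diff):

    assert _get_level(0) == 40  # ERROR: only failures reach the terminal
    assert _get_level(1) == 20  # INFO
    assert _get_level(5) == 10  # DEBUG: counts above 2 are clamped to 2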
-def create_experiment_dir(experiment_config: Dict) -> Path:
+def _create_experiment_dir(
+ experiment_config: Dict, checkpoint: Optional[str] = None
+) -> Tuple[Path, Path, Path]:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment_dir = EXPERIMENTS_DIRNAME / (
@@ -48,19 +48,21 @@ def create_experiment_dir(experiment_config: Dict) -> Path:
+ f"{experiment_config['dataset']['type']}_"
+ f"{experiment_config['network']['type']}"
)
- if experiment_config["resume_experiment"] is None:
+
+ if checkpoint is None:
experiment = datetime.now().strftime("%m%d_%H%M%S")
logger.debug(f"Creating a new experiment called {experiment}")
else:
available_experiments = glob(str(experiment_dir) + "/*")
available_experiments.sort()
- if experiment_config["resume_experiment"] == "last":
+ if checkpoint == "last":
experiment = available_experiments[-1]
logger.debug(f"Resuming the latest experiment {experiment}")
else:
- experiment = experiment_config["resume_experiment"]
+ experiment = checkpoint
if not str(experiment_dir / experiment) in available_experiments:
raise FileNotFoundError("Experiment does not exist.")
+ logger.debug(f"Resuming the from experiment {checkpoint}")
experiment_dir = experiment_dir / experiment
@@ -71,14 +73,10 @@ def create_experiment_dir(experiment_config: Dict) -> Path:
return experiment_dir, log_dir, model_dir
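A note on the resume-by-"last" lookup above: because experiment directories are named with the zero-padded "%m%d_%H%M%S" timestamp, lexicographic order matches chronological order (within a calendar year), so sorting the glob results and taking the final entry yields the most recent run. A minimal sketch of that assumption:

    from glob import glob

    available_experiments = sorted(glob(str(experiment_dir) + "/*"))
    latest = available_experiments[-1]  # e.g. ".../1108_145444" sorts after ".../1107_093012"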
-def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
+def _load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
- # Import the data loader arguments.
- train_args = experiment_config.get("train_args", {})
-
# Load the dataset module.
dataset_args = experiment_config.get("dataset", {})
- dataset_args["train_args"]["batch_size"] = train_args["batch_size"]
datasets_module = importlib.import_module("text_recognizer.datasets")
dataset_ = getattr(datasets_module, dataset_args["type"])
@@ -102,21 +100,24 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
network_args = experiment_config["network"].get("args", {})
# Criterion
- if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
- criterion_ = getattr(losses, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {})
+ if experiment_config["criterion"]["type"] in custom_loss_module.__all__:
+ criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"])
else:
criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {})
+ criterion_args = experiment_config["criterion"].get("args", {}) or {}
# Optimizers
- optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+ if experiment_config["optimizer"]["type"] == "AdaBelief":
+ warnings.filterwarnings("ignore", category=UserWarning)
+ optimizer_ = getattr(adabelief_pytorch, experiment_config["optimizer"]["type"])
+ else:
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
optimizer_args = experiment_config["optimizer"].get("args", {})
# Learning rate scheduler
lr_scheduler_ = None
lr_scheduler_args = None
- if experiment_config["lr_scheduler"] is not None:
+ if "lr_scheduler" in experiment_config:
lr_scheduler_ = getattr(
torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
)
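Every component above is resolved with the same pattern: read a "type" string from the config and look it up with getattr on the appropriate module, with AdaBelief special-cased because it lives in the adabelief_pytorch package rather than torch.optim. A hedged sketch of the pattern (the _resolve helper is illustrative, not in the diff):

    import importlib

    def _resolve(module_path: str, type_name: str) -> type:
        # Illustrative only: look up a class by its config name.
        return getattr(importlib.import_module(module_path), type_name)

    # criterion_ = _resolve("torch.nn", "CrossEntropyLoss")
    # optimizer_ = _resolve("adabelief_pytorch", "AdaBelief")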
@@ -146,10 +147,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
return model_class_, model_args
-def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList:
+def _configure_callbacks(experiment_config: Dict, model_dir: Path) -> CallbackList:
"""Configure a callback list for trainer."""
if "Checkpoint" in experiment_config["callback_args"]:
- experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir
+ experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = str(
+ model_dir
+ )
# Initializes callbacks.
callback_modules = importlib.import_module("training.trainer.callbacks")
@@ -161,13 +164,13 @@ def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackLis
return callbacks
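For context, callback_args is expected to map each callback class name to its keyword arguments; a hypothetical fragment consistent with the Checkpoint handling above (the EarlyStopping entry and all values are assumptions, not from the diff):

    experiment_config["callback_args"] = {
        "Checkpoint": {"monitor": "val_loss"},  # "checkpoint_path" is injected above
        "EarlyStopping": {"monitor": "val_loss", "patience": 10},  # assumed
    }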
-def configure_logger(experiment_config: Dict, log_dir: Path) -> None:
+def _configure_logger(log_dir: Path, verbose: int = 0) -> None:
"""Configure the loguru logger for output to terminal and disk."""
# Have to remove default logger to get tqdm to work properly.
logger.remove()
# Fetch verbosity level.
- level = get_level(experiment_config)
+ level = _get_level(verbose)
logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=level)
logger.add(
@@ -176,20 +179,29 @@ def configure_logger(experiment_config: Dict, log_dir: Path) -> None:
)
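Routing terminal output through tqdm.write is the usual loguru/tqdm recipe: the default stderr sink would print through an active progress bar, so it is removed and replaced by a sink that writes above the bar. A minimal standalone sketch:

    from loguru import logger
    from tqdm import tqdm

    logger.remove()  # drop the default stderr sink, which clashes with progress bars
    logger.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=20)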
-def save_config(experiment_dir: Path, experiment_config: Dict) -> None:
+def _save_config(experiment_dir: Path, experiment_config: Dict) -> None:
"""Copy config to experiment directory."""
config_path = experiment_dir / "config.yml"
with open(str(config_path), "w") as f:
yaml.dump(experiment_config, f)
-def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None:
+def _load_from_checkpoint(
+ model: Type[Model], model_dir: Path, pretrained_weights: Optional[str] = None,
+) -> None:
"""If checkpoint exists, load model weights and optimizers from checkpoint."""
# Get checkpoint path.
- checkpoint_path = model_dir / "last.pt"
+ if pretrained_weights is not None:
+ logger.info(f"Loading weights from {pretrained_weights}.")
+ checkpoint_path = (
+ EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt"
+ )
+ else:
+ logger.info(f"Loading weights from {model_dir}.")
+ checkpoint_path = model_dir / "last.pt"
if checkpoint_path.exists():
- logger.info("Loading and resuming training from last checkpoint.")
- model.load_checkpoint(checkpoint_path)
+ logger.info("Loading and resuming training from checkpoint.")
+ model.load_from_checkpoint(checkpoint_path)
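The precedence above: an explicit --pretrained_weights run warm-starts from that experiment's best.pt, otherwise a resumed run continues from its own last.pt. An illustrative restatement (not in the diff) of the path resolution:

    def _checkpoint_path(model_dir: Path, pretrained_weights: Optional[str]) -> Path:
        # Mirrors _load_from_checkpoint above; for illustration only.
        if pretrained_weights is not None:
            return EXPERIMENTS_DIRNAME / pretrained_weights / "model" / "best.pt"
        return model_dir / "last.pt"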
def evaluate_embedding(model: Type[Model]) -> Dict:
@@ -217,38 +229,50 @@ def evaluate_embedding(model: Type[Model]) -> Dict:
def run_experiment(
- experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
+ experiment_config: Dict,
+ save_weights: bool,
+ device: str,
+ use_wandb: bool,
+ train: bool,
+ test: bool,
+ verbose: int = 0,
+ checkpoint: Optional[str] = None,
+ pretrained_weights: Optional[str] = None,
) -> None:
"""Runs an experiment."""
logger.info(f"Experiment config: {json.dumps(experiment_config)}")
# Create new experiment.
- experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)
+ experiment_dir, log_dir, model_dir = _create_experiment_dir(
+ experiment_config, checkpoint
+ )
# Make sure the log/model directory exists.
log_dir.mkdir(parents=True, exist_ok=True)
model_dir.mkdir(parents=True, exist_ok=True)
# Load the modules and model arguments.
- model_class_, model_args = load_modules_and_arguments(experiment_config)
+ model_class_, model_args = _load_modules_and_arguments(experiment_config)
# Initializes the model with experiment config.
model = model_class_(**model_args, device=device)
- callbacks = configure_callbacks(experiment_config, model_dir)
+ callbacks = _configure_callbacks(experiment_config, model_dir)
# Setup logger.
- configure_logger(experiment_config, log_dir)
+ _configure_logger(log_dir, verbose)
# Load from checkpoint if resuming an experiment.
- if experiment_config["resume_experiment"] is not None:
- load_from_checkpoint(model, log_dir, model_dir)
+ resume = False
+ if checkpoint is not None or pretrained_weights is not None:
+ resume = True
+ _load_from_checkpoint(model, model_dir, pretrained_weights)
logger.info(f"The class mapping is {model.mapping}")
# Initializes Weights & Biases
if use_wandb:
- wandb.init(project="text-recognizer", config=experiment_config)
+ wandb.init(project="text-recognizer", config=experiment_config, resume=resume)
# Lets W&B save the model and track the gradients and optional parameters.
wandb.watch(model.network)
@@ -265,23 +289,30 @@ def run_experiment(
experiment_config["device"] = device
# Save the config used in the experiment folder.
- save_config(experiment_dir, experiment_config)
+ _save_config(experiment_dir, experiment_config)
+
+ # Prints a summary of the network in terminal.
+ model.summary(experiment_config["train_args"]["input_shape"])
# Load trainer.
trainer = Trainer(
- max_epochs=experiment_config["train_args"]["max_epochs"], callbacks=callbacks,
+ max_epochs=experiment_config["train_args"]["max_epochs"],
+ callbacks=callbacks,
+ transformer_model=experiment_config["train_args"]["transformer_model"],
+ max_norm=experiment_config["train_args"]["max_norm"],
)
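The trainer now reads several keys from train_args beyond the defaults at the top of the file; a hypothetical fragment covering every key referenced above (all values are placeholders, and the input_shape layout is an assumption):

    experiment_config["train_args"] = {
        "batch_size": 64,
        "max_epochs": 16,
        "input_shape": (1, 28, 952),  # assumed (channels, height, width) for model.summary
        "transformer_model": False,   # toggles the transformer-specific training path
        "max_norm": 1.0,              # assumed gradient clipping threshold
    }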
# Train the model.
- trainer.fit(model)
+ if train:
+ trainer.fit(model)
# Run inference over test set.
- if experiment_config["test"]:
+ if test:
logger.info("Loading checkpoint with the best weights.")
model.load_from_checkpoint(model_dir / "best.pt")
logger.info("Running inference on test set.")
- if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
+ if experiment_config["criterion"]["type"] == "EmbeddingLoss":
logger.info("Evaluating embedding.")
score = evaluate_embedding(model)
else:
@@ -313,7 +344,26 @@ def run_experiment(
@click.option(
"--nowandb", is_flag=False, help="If true, do not use wandb for this run."
)
-def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
+@click.option("--test", is_flag=True, help="If true, test the model.")
+@click.option("-v", "--verbose", count=True)
+@click.option("--checkpoint", type=str, help="Path to the experiment.")
+@click.option(
+ "--pretrained_weights", type=str, help="Path to pretrained model weights."
+)
+@click.option(
+ "--notrain", is_flag=False, is_eager=True, help="Do not train the model.",
+)
+def run_cli(
+ experiment_config: str,
+ gpu: int,
+ save: bool,
+ nowandb: bool,
+ notrain: bool,
+ test: bool,
+ verbose: int,
+ checkpoint: Optional[str] = None,
+ pretrained_weights: Optional[str] = None,
+) -> None:
"""Run experiment."""
if gpu < 0:
gpu_manager = GPUManager(True)
@@ -322,7 +372,18 @@ def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None
experiment_config = json.loads(experiment_config)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
- run_experiment(experiment_config, save, device, use_wandb=not nowandb)
+
+ run_experiment(
+ experiment_config,
+ save,
+ device,
+ use_wandb=not nowandb,
+ train=not notrain,
+ test=test,
+ verbose=verbose,
+ checkpoint=checkpoint,
+ pretrained_weights=pretrained_weights,
+ )
if __name__ == "__main__":
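A hypothetical invocation tying the new flags together (assuming the elided --experiment_config option takes a JSON string, as the json.loads call above implies; the config file path is illustrative):

    python run_experiment.py \
        --experiment_config "$(cat experiments/embedding_experiment.json)" \
        --gpu=-1 -vv --test --checkpoint last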