summaryrefslogtreecommitdiff
path: root/src/training/run_experiment.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-08 23:14:23 +0200
commite1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (patch)
tree70b482f890c9ad2be104f0bff8f2172e8411a2be /src/training/run_experiment.py
parentfe23001b6588e6e6e9e2c5a99b72f3445cf5206f (diff)
IAM datasets implemented.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--src/training/run_experiment.py238
1 files changed, 143 insertions, 95 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 8c063ff..4317d66 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,18 +6,19 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, Tuple, Type
+from typing import Callable, Dict, List, Tuple, Type
import click
from loguru import logger
import torch
from tqdm import tqdm
from training.gpu_manager import GPUManager
-from training.trainer.callbacks import CallbackList
+from training.trainer.callbacks import Callback, CallbackList
from training.trainer.train import Trainer
import wandb
import yaml
+
from text_recognizer.models import Model
@@ -37,10 +38,14 @@ def get_level(experiment_config: Dict) -> int:
return 10
-def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
+def create_experiment_dir(experiment_config: Dict) -> Path:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
- experiment_dir = EXPERIMENTS_DIRNAME / model.__name__
+ experiment_dir = EXPERIMENTS_DIRNAME / (
+ f"{experiment_config['model']}_"
+ + f"{experiment_config['dataset']['type']}_"
+ + f"{experiment_config['network']['type']}"
+ )
if experiment_config["resume_experiment"] is None:
experiment = datetime.now().strftime("%m%d_%H%M%S")
logger.debug(f"Creating a new experiment called {experiment}")
@@ -54,70 +59,89 @@ def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
experiment = experiment_config["resume_experiment"]
if not str(experiment_dir / experiment) in available_experiments:
raise FileNotFoundError("Experiment does not exist.")
- logger.debug(f"Resuming the experiment {experiment}")
experiment_dir = experiment_dir / experiment
- return experiment_dir
+ # Create log and model directories.
+ log_dir = experiment_dir / "log"
+ model_dir = experiment_dir / "model"
+
+ return experiment_dir, log_dir, model_dir
-def check_args(args: Dict) -> Dict:
+
+def check_args(args: Dict, train_args: Dict) -> Dict:
"""Checks that the arguments are not None."""
+ args = args or {}
+
+ # I just want to set total epochs in train args, instead of changing all parameters.
+ if "epochs" in args and args["epochs"] is None:
+ args["epochs"] = train_args["max_epochs"]
+
+ # For CosineAnnealingLR.
+ if "T_max" in args and args["T_max"] is None:
+ args["T_max"] = train_args["max_epochs"]
+
return args or {}
def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
# Import the data loader arguments.
- data_loader_args = experiment_config.get("data_loader_args", {})
train_args = experiment_config.get("train_args", {})
- data_loader_args["batch_size"] = train_args["batch_size"]
- data_loader_args["dataset"] = experiment_config["dataset"]
- data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
+
+ # Load the dataset module.
+ dataset_args = experiment_config.get("dataset", {})
+ dataset_args["train_args"]["batch_size"] = train_args["batch_size"]
+ datasets_module = importlib.import_module("text_recognizer.datasets")
+ dataset_ = getattr(datasets_module, dataset_args["type"])
# Import the model module and model arguments.
models_module = importlib.import_module("text_recognizer.models")
model_class_ = getattr(models_module, experiment_config["model"])
# Import metrics.
- metric_fns_ = {
- metric: getattr(models_module, metric)
- for metric in experiment_config["metrics"]
- }
+ metric_fns_ = (
+ {
+ metric: getattr(models_module, metric)
+ for metric in experiment_config["metrics"]
+ }
+ if experiment_config["metrics"] is not None
+ else None
+ )
# Import network module and arguments.
network_module = importlib.import_module("text_recognizer.networks")
- network_fn_ = getattr(network_module, experiment_config["network"])
- network_args = experiment_config.get("network_args", {})
+ network_fn_ = getattr(network_module, experiment_config["network"]["type"])
+ network_args = experiment_config["network"].get("args", {})
# Criterion
- criterion_ = getattr(torch.nn, experiment_config["criterion"])
- criterion_args = experiment_config.get("criterion_args", {})
+ criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
+ criterion_args = experiment_config["criterion"].get("args", {})
- # Optimizer
- optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
- optimizer_args = experiment_config.get("optimizer_args", {})
-
- # Callbacks
- callback_modules = importlib.import_module("training.trainer.callbacks")
- callbacks = [
- getattr(callback_modules, callback)(
- **check_args(experiment_config["callback_args"][callback])
- )
- for callback in experiment_config["callbacks"]
- ]
+ # Optimizers
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+ optimizer_args = experiment_config["optimizer"].get("args", {})
# Learning rate scheduler
+ lr_scheduler_ = None
+ lr_scheduler_args = None
if experiment_config["lr_scheduler"] is not None:
lr_scheduler_ = getattr(
- torch.optim.lr_scheduler, experiment_config["lr_scheduler"]
+ torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
+ )
+ lr_scheduler_args = check_args(
+ experiment_config["lr_scheduler"].get("args", {}), train_args
)
- lr_scheduler_args = experiment_config.get("lr_scheduler_args", {})
+
+ # SWA scheduler.
+ if "swa_args" in experiment_config:
+ swa_args = check_args(experiment_config.get("swa_args", {}), train_args)
else:
- lr_scheduler_ = None
- lr_scheduler_args = None
+ swa_args = None
model_args = {
- "data_loader_args": data_loader_args,
+ "dataset": dataset_,
+ "dataset_args": dataset_args,
"metrics": metric_fns_,
"network_fn": network_fn_,
"network_args": network_args,
@@ -127,43 +151,33 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
"optimizer_args": optimizer_args,
"lr_scheduler": lr_scheduler_,
"lr_scheduler_args": lr_scheduler_args,
+ "swa_args": swa_args,
}
- return model_class_, model_args, callbacks
-
+ return model_class_, model_args
-def run_experiment(
- experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
-) -> None:
- """Runs an experiment."""
-
- # Load the modules and model arguments.
- model_class_, model_args, callbacks = load_modules_and_arguments(experiment_config)
-
- # Initializes the model with experiment config.
- model = model_class_(**model_args, device=device)
- # Instantiate a CallbackList.
- callbacks = CallbackList(model, callbacks)
-
- # Create new experiment.
- experiment_dir = create_experiment_dir(model, experiment_config)
+def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList:
+ """Configure a callback list for trainer."""
+ train_args = experiment_config.get("train_args", {})
- # Create log and model directories.
- log_dir = experiment_dir / "log"
- model_dir = experiment_dir / "model"
+ if "Checkpoint" in experiment_config["callback_args"]:
+ experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir
- # Set the model dir to be able to save checkpoints.
- model.model_dir = model_dir
+ # Callbacks
+ callback_modules = importlib.import_module("training.trainer.callbacks")
+ callbacks = [
+ getattr(callback_modules, callback)(
+ **check_args(experiment_config["callback_args"][callback], train_args)
+ )
+ for callback in experiment_config["callbacks"]
+ ]
- # Get checkpoint path.
- checkpoint_path = model_dir / "last.pt"
- if not checkpoint_path.exists():
- checkpoint_path = None
+ return callbacks
- # Make sure the log directory exists.
- log_dir.mkdir(parents=True, exist_ok=True)
+def configure_logger(experiment_config: Dict, log_dir: Path) -> None:
+ """Configure the loguru logger for output to terminal and disk."""
# Have to remove default logger to get tqdm to work properly.
logger.remove()
@@ -176,13 +190,50 @@ def run_experiment(
format="{time:YYYY-MM-DD at HH:mm:ss} | {level} | {message}",
)
- if "cuda" in device:
- gpu_index = re.sub("[^0-9]+", "", device)
- logger.info(
- f"Running experiment with config {experiment_config} on GPU {gpu_index}"
- )
- else:
- logger.info(f"Running experiment with config {experiment_config} on CPU")
+
+def save_config(experiment_dir: Path, experiment_config: Dict) -> None:
+ """Copy config to experiment directory."""
+ config_path = experiment_dir / "config.yml"
+ with open(str(config_path), "w") as f:
+ yaml.dump(experiment_config, f)
+
+
+def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> None:
+ """If checkpoint exists, load model weights and optimizers from checkpoint."""
+ # Get checkpoint path.
+ checkpoint_path = model_dir / "last.pt"
+ if checkpoint_path.exists():
+ logger.info("Loading and resuming training from last checkpoint.")
+ model.load_checkpoint(checkpoint_path)
+
+
+def run_experiment(
+ experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
+) -> None:
+ """Runs an experiment."""
+ logger.info(f"Experiment config: {json.dumps(experiment_config)}")
+
+ # Create new experiment.
+ experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)
+
+ # Make sure the log/model directory exists.
+ log_dir.mkdir(parents=True, exist_ok=True)
+ model_dir.mkdir(parents=True, exist_ok=True)
+
+ # Load the modules and model arguments.
+ model_class_, model_args = load_modules_and_arguments(experiment_config)
+
+ # Initializes the model with experiment config.
+ model = model_class_(**model_args, device=device)
+
+ callbacks = configure_callbacks(experiment_config, model_dir)
+
+ # Setup logger.
+ configure_logger(experiment_config, log_dir)
+
+ # Load from checkpoint if resuming an experiment.
+ if experiment_config["resume_experiment"] is not None:
+ load_from_checkpoint(model, log_dir, model_dir)
logger.info(f"The class mapping is {model.mapping}")
@@ -193,9 +244,6 @@ def run_experiment(
# Lets W&B save the model and track the gradients and optional parameters.
wandb.watch(model.network)
- # Prints a summary of the network in terminal.
- model.summary()
-
experiment_config["train_args"] = {
**DEFAULT_TRAIN_ARGS,
**experiment_config.get("train_args", {}),
@@ -208,41 +256,41 @@ def run_experiment(
experiment_config["device"] = device
# Save the config used in the experiment folder.
- config_path = experiment_dir / "config.yml"
- with open(str(config_path), "w") as f:
- yaml.dump(experiment_config, f)
+ save_config(experiment_dir, experiment_config)
- # Train the model.
+ # Load trainer.
trainer = Trainer(
- model=model,
- model_dir=model_dir,
- train_args=experiment_config["train_args"],
- callbacks=callbacks,
- checkpoint_path=checkpoint_path,
+ max_epochs=experiment_config["train_args"]["max_epochs"], callbacks=callbacks,
)
- trainer.fit()
+ # Train the model.
+ trainer.fit(model)
- logger.info("Loading checkpoint with the best weights.")
- model.load_checkpoint(model_dir / "best.pt")
+ # Run inference over test set.
+ if experiment_config["test"]:
+ logger.info("Loading checkpoint with the best weights.")
+ model.load_from_checkpoint(model_dir / "best.pt")
- score = trainer.validate()
+ logger.info("Running inference on test set.")
+ score = trainer.test(model)
- logger.info(f"Validation set evaluation: {score}")
+ logger.info(f"Test set evaluation: {score}")
- if use_wandb:
- wandb.log({"validation_metric": score["val_accuracy"]})
+ if use_wandb:
+ wandb.log(
+ {
+ experiment_config["test_metric"]: score[
+ experiment_config["test_metric"]
+ ]
+ }
+ )
if save_weights:
model.save_weights(model_dir)
@click.command()
-@click.option(
- "--experiment_config",
- type=str,
- help='Experiment JSON, e.g. \'{"dataloader": "EmnistDataLoader", "model": "CharacterModel", "network": "mlp"}\'',
-)
+@click.argument("experiment_config",)
@click.option("--gpu", type=int, default=0, help="Provide the index of the GPU to use.")
@click.option(
"--save",