path: root/src/training/run_experiment.py
author    aktersnurra <gustaf.rydholm@gmail.com>  2020-08-03 23:33:34 +0200
committer aktersnurra <gustaf.rydholm@gmail.com>  2020-08-03 23:33:34 +0200
commit    07dd14116fe1d8148fb614b160245287533620fc (patch)
tree      63395d88b17a14ad453c52889fcf541e6cbbdd3e /src/training/run_experiment.py
parent    704451318eb6b0b600ab314cb5aabfac82416bda (diff)
Working EMNIST Lines dataset.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--  src/training/run_experiment.py | 45
1 file changed, 35 insertions(+), 10 deletions(-)
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 0b29ce9..c133ce5 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -12,8 +12,10 @@ import click
from loguru import logger
import torch
from tqdm import tqdm
+from training.callbacks import CallbackList
from training.gpu_manager import GPUManager
from training.train import Trainer
+import wandb
import yaml
@@ -48,9 +50,8 @@ def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path:
logger.debug(f"Resuming the latest experiment {experiment}")
else:
experiment = experiment_config["resume_experiment"]
- assert (
- str(experiment_dir / experiment) in available_experiments
- ), "Experiment does not exist."
+ if str(experiment_dir / experiment) not in available_experiments:
+ raise FileNotFoundError("Experiment does not exist.")
logger.debug(f"Resuming the experiment {experiment}")
experiment_dir = experiment_dir / experiment
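
Swapping the assert for an explicit raise is more robust: assert statements are stripped when Python runs with -O, while the FileNotFoundError always fires. A minimal, self-contained sketch of the check (the paths are hypothetical stand-ins; in the script they come from the experiment config):

    from pathlib import Path

    # Hypothetical layout for illustration only.
    experiment_dir = Path("experiments")
    available_experiments = [str(experiment_dir / "0803_233334")]

    experiment = "0803_233334"
    if str(experiment_dir / experiment) not in available_experiments:
        raise FileNotFoundError("Experiment does not exist.")
    experiment_dir = experiment_dir / experiment
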
@@ -87,6 +88,13 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
optimizer_args = experiment_config.get("optimizer_args", {})
+ # Callbacks
+ callback_modules = importlib.import_module("training.callbacks")
+ callbacks = []
+ for callback in experiment_config["callbacks"]:
+ args = experiment_config["callback_args"][callback] or {}
+ callbacks.append(getattr(callback_modules, callback)(**args))
+
# Learning rate scheduler
if experiment_config["lr_scheduler"] is not None:
lr_scheduler_ = getattr(
@@ -111,7 +119,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
"lr_scheduler_args": lr_scheduler_args,
}
- return model_class_, model_args
+ return model_class_, model_args, callbacks
def run_experiment(
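
load_modules_and_arguments now also builds the callback instances, resolving each class name from the config at runtime with importlib and getattr, and returns them alongside the model class and arguments. A self-contained sketch of that pattern; the EarlyStopping name, its patience argument, and the stub module are illustrative assumptions, not the repo's actual callbacks:

    import types

    # Stand-in for importlib.import_module("training.callbacks").
    callback_modules = types.SimpleNamespace(
        EarlyStopping=lambda patience=3: ("EarlyStopping", patience)
    )

    experiment_config = {
        "callbacks": ["EarlyStopping"],
        "callback_args": {"EarlyStopping": {"patience": 5}},
    }

    callbacks = []
    for callback in experiment_config["callbacks"]:
        args = experiment_config["callback_args"][callback] or {}
        callbacks.append(getattr(callback_modules, callback)(**args))

    print(callbacks)  # [('EarlyStopping', 5)]

The "or {}" guard keeps the call working when a callback's args entry is left empty (null) in the YAML config.
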
@@ -120,11 +128,14 @@ def run_experiment(
"""Runs an experiment."""
# Load the modules and model arguments.
- model_class_, model_args = load_modules_and_arguments(experiment_config)
+ model_class_, model_args, callbacks = load_modules_and_arguments(experiment_config)
# Initializes the model with experiment config.
model = model_class_(**model_args, device=device)
+ # Instantiate a CallbackList.
+ callbacks = CallbackList(model, callbacks)
+
# Create new experiment.
experiment_dir = create_experiment_dir(model, experiment_config)
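
The callbacks are then wrapped in a CallbackList bound to the model. The repo's implementation is not shown in this diff; the sketch below only illustrates the usual contract of such a container, fanning each training event out to every registered callback:

    class CallbackList:
        """Generic sketch; the actual training.callbacks.CallbackList may differ."""

        def __init__(self, model, callbacks):
            self.model = model
            self.callbacks = list(callbacks)

        def on_epoch_end(self, epoch, logs=None):
            for callback in self.callbacks:
                callback.on_epoch_end(epoch, logs)
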
@@ -132,6 +143,9 @@ def run_experiment(
log_dir = experiment_dir / "log"
model_dir = experiment_dir / "model"
+ # Set the model dir so that checkpoints can be saved.
+ model.model_dir = model_dir
+
# Get checkpoint path.
checkpoint_path = model_dir / "last.pt"
if not checkpoint_path.exists():
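
Storing model_dir on the model lets checkpoint-saving callbacks write to the right place. The run resumes from last.pt when it exists, and (further down) reloads the best validation weights from best.pt after training. A sketch of the resume decision with a hypothetical directory:

    from pathlib import Path

    model_dir = Path("model")  # hypothetical; really lives under the experiment dir
    checkpoint_path = model_dir / "last.pt"
    if not checkpoint_path.exists():
        checkpoint_path = None  # assumption: start fresh when there is no checkpoint
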
@@ -162,6 +176,13 @@ def run_experiment(
logger.info(f"The class mapping is {model.mapping}")
+ # Initializes Weights & Biases.
+ if use_wandb:
+ wandb.init(project="text-recognizer", config=experiment_config)
+
+ # Let W&B watch the model: track the gradients and, optionally, the parameters.
+ wandb.watch(model.network)
+
# Prints a summary of the network to the terminal.
model.summary()
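
wandb.init starts the run and records the whole experiment config, while wandb.watch hooks the network so gradients are logged automatically. A minimal sketch with a stand-in network; the project name comes from this diff, the config content is hypothetical:

    import torch.nn as nn
    import wandb

    network = nn.Linear(28 * 28, 10)  # stand-in for model.network

    wandb.init(project="text-recognizer", config={"train_args": {"epochs": 10}})
    wandb.watch(network)  # logs gradients by default; log="all" adds parameters
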
@@ -181,21 +202,26 @@ def run_experiment(
with open(str(config_path), "w") as f:
yaml.dump(experiment_config, f)
- # TODO: wandb
trainer = Trainer(
model=model,
model_dir=model_dir,
- epochs=experiment_config["train_args"]["epochs"],
- val_metric=experiment_config["train_args"]["val_metric"],
+ train_args=experiment_config["train_args"],
+ callbacks=callbacks,
checkpoint_path=checkpoint_path,
)
trainer.fit()
+ logger.info("Loading checkpoint with the best weights.")
+ model.load_checkpoint(model_dir / "best.pt")
+
score = trainer.validate()
logger.info(f"Validation set evaluation: {score}")
+ if use_wandb:
+ wandb.log({"validation_metric": score["val_accuracy"]})
+
if save_weights:
model.save_weights(model_dir)
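
Passing the whole train_args block, rather than individual epochs and val_metric keyword arguments, means new training options can be added in the YAML config without touching this call site. Judging only from the two keys visible in the removed lines, the block plausibly looks like this (values are hypothetical):

    train_args = {
        "epochs": 10,
        "val_metric": "val_accuracy",
    }

Note also that after fit() the best checkpoint is reloaded before validation, so the reported score reflects the best weights rather than the last epoch's.
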
@@ -220,12 +246,11 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
if gpu < 0:
gpu_manager = GPUManager(True)
gpu = gpu_manager.get_free_gpu()
-
device = "cuda:" + str(gpu)
experiment_config = json.loads(experiment_config)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
- run_experiment(experiment_config, save, device, nowandb)
+ run_experiment(experiment_config, save, device, use_wandb=not nowandb)
if __name__ == "__main__":
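
When the CLI is invoked with a negative gpu index, a free device is picked via the repo's GPUManager and pinned through CUDA_VISIBLE_DEVICES. A sketch of that selection logic with the manager stubbed out (get_free_gpu here is a stand-in, not the real implementation):

    import os

    def get_free_gpu() -> int:
        """Stand-in for GPUManager(True).get_free_gpu()."""
        return 0

    gpu = -1  # negative means "pick a free GPU for me"
    if gpu < 0:
        gpu = get_free_gpu()
    device = "cuda:" + str(gpu)
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
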