summary | refs | log | tree | commit | diff
path: root/src/training/run_experiment.py
diff options
context:
space:
mode:
author aktersnurra <gustaf.rydholm@gmail.com> 2020-09-20 00:14:27 +0200
committer aktersnurra <gustaf.rydholm@gmail.com> 2020-09-20 00:14:27 +0200
commit e181195a699d7fa237f256d90ab4dedffc03d405 (patch)
tree 6d8d50731a7267c56f7bf3ed5ecec3990c0e55a5 /src/training/run_experiment.py
parent 3b06ef615a8db67a03927576e0c12fbfb2501f5f (diff)
Minor bug fixes etc.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r-- src/training/run_experiment.py 79
1 files changed, 46 insertions, 33 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 286b0c6..a347d9f 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Tuple, Type
import click
from loguru import logger
+import numpy as np
import torch
from tqdm import tqdm
from training.gpu_manager import GPUManager
@@ -20,11 +21,12 @@ import yaml
from text_recognizer.models import Model
+from text_recognizer.networks import losses
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-
+CUSTOM_LOSSES = ["EmbeddingLoss"]
DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}
@@ -69,21 +71,6 @@ def create_experiment_dir(experiment_config: Dict) -> Path:
return experiment_dir, log_dir, model_dir
-def check_args(args: Dict, train_args: Dict) -> Dict:
- """Checks that the arguments are not None."""
- args = args or {}
-
- # I just want to set total epochs in train args, instead of changing all parameter.
- if "epochs" in args and args["epochs"] is None:
- args["epochs"] = train_args["max_epochs"]
-
- # For CosineAnnealingLR.
- if "T_max" in args and args["T_max"] is None:
- args["T_max"] = train_args["max_epochs"]
-
- return args or {}
-
-
def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
# Import the data loader arguments.
@@ -115,8 +102,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
network_args = experiment_config["network"].get("args", {})
# Criterion
- criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {})
+ if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
+ criterion_ = getattr(losses, experiment_config["criterion"]["type"])
+ criterion_args = experiment_config["criterion"].get("args", {})
+ else:
+ criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
+ criterion_args = experiment_config["criterion"].get("args", {})
# Optimizers
optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
@@ -129,13 +120,11 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
lr_scheduler_ = getattr(
torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"]
)
- lr_scheduler_args = check_args(
- experiment_config["lr_scheduler"].get("args", {}), train_args
- )
+ lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {}
# SWA scheduler.
if "swa_args" in experiment_config:
- swa_args = check_args(experiment_config.get("swa_args", {}), train_args)
+ swa_args = experiment_config.get("swa_args", {}) or {}
else:
swa_args = None
@@ -159,19 +148,15 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList:
"""Configure a callback list for trainer."""
- train_args = experiment_config.get("train_args", {})
-
if "Checkpoint" in experiment_config["callback_args"]:
experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir
- # Callbacks
+ # Initializes callbacks.
callback_modules = importlib.import_module("training.trainer.callbacks")
- callbacks = [
- getattr(callback_modules, callback)(
- **check_args(experiment_config["callback_args"][callback], train_args)
- )
- for callback in experiment_config["callbacks"]
- ]
+ callbacks = []
+ for callback in experiment_config["callbacks"]:
+ args = experiment_config["callback_args"][callback] or {}
+ callbacks.append(getattr(callback_modules, callback)(**args))
return callbacks
@@ -207,11 +192,35 @@ def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) ->
model.load_checkpoint(checkpoint_path)
+def evaluate_embedding(model: Type[Model]) -> Dict:
+ """Evaluates the embedding space."""
+ from pytorch_metric_learning import testers
+ from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
+
+ accuracy_calculator = AccuracyCalculator(
+ include=("mean_average_precision_at_r",), k=10
+ )
+
+ def get_all_embeddings(model: Type[Model]) -> Tuple:
+ tester = testers.BaseTester()
+ return tester.get_all_embeddings(model.test_dataset, model.network)
+
+ embeddings, labels = get_all_embeddings(model)
+ logger.info("Computing embedding accuracy")
+ accuracies = accuracy_calculator.get_accuracy(
+ embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True
+ )
+ logger.info(
+ f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}"
+ )
+ return accuracies
+
+
def run_experiment(
experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False
) -> None:
"""Runs an experiment."""
- logger.info(f"Experiment config: {json.dumps(experiment_config, indent=2)}")
+ logger.info(f"Experiment config: {json.dumps(experiment_config)}")
# Create new experiment.
experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config)
@@ -272,7 +281,11 @@ def run_experiment(
model.load_from_checkpoint(model_dir / "best.pt")
logger.info("Running inference on test set.")
- score = trainer.test(model)
+ if experiment_config["criterion"]["type"] in CUSTOM_LOSSES:
+ logger.info("Evaluating embedding.")
+ score = evaluate_embedding(model)
+ else:
+ score = trainer.test(model)
logger.info(f"Test set evaluation: {score}")