diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-20 00:14:27 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-09-20 00:14:27 +0200 |
commit | e181195a699d7fa237f256d90ab4dedffc03d405 (patch) | |
tree | 6d8d50731a7267c56f7bf3ed5ecec3990c0e55a5 /src/training/run_experiment.py | |
parent | 3b06ef615a8db67a03927576e0c12fbfb2501f5f (diff) |
Minor bug fixes etc.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r-- | src/training/run_experiment.py | 79 |
1 files changed, 46 insertions, 33 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 286b0c6..a347d9f 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Tuple, Type import click from loguru import logger +import numpy as np import torch from tqdm import tqdm from training.gpu_manager import GPUManager @@ -20,11 +21,12 @@ import yaml from text_recognizer.models import Model +from text_recognizer.networks import losses EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" - +CUSTOM_LOSSES = ["EmbeddingLoss"] DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16} @@ -69,21 +71,6 @@ def create_experiment_dir(experiment_config: Dict) -> Path: return experiment_dir, log_dir, model_dir -def check_args(args: Dict, train_args: Dict) -> Dict: - """Checks that the arguments are not None.""" - args = args or {} - - # I just want to set total epochs in train args, instead of changing all parameter. - if "epochs" in args and args["epochs"] is None: - args["epochs"] = train_args["max_epochs"] - - # For CosineAnnealingLR. - if "T_max" in args and args["T_max"] is None: - args["T_max"] = train_args["max_epochs"] - - return args or {} - - def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]: """Loads all modules and arguments.""" # Import the data loader arguments. @@ -115,8 +102,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] network_args = experiment_config["network"].get("args", {}) # Criterion - criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) - criterion_args = experiment_config["criterion"].get("args", {}) + if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: + criterion_ = getattr(losses, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) + else: + criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) # Optimizers optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) @@ -129,13 +120,11 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] lr_scheduler_ = getattr( torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] ) - lr_scheduler_args = check_args( - experiment_config["lr_scheduler"].get("args", {}), train_args - ) + lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} # SWA scheduler. if "swa_args" in experiment_config: - swa_args = check_args(experiment_config.get("swa_args", {}), train_args) + swa_args = experiment_config.get("swa_args", {}) or {} else: swa_args = None @@ -159,19 +148,15 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList: """Configure a callback list for trainer.""" - train_args = experiment_config.get("train_args", {}) - if "Checkpoint" in experiment_config["callback_args"]: experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir - # Callbacks + # Initializes callbacks. callback_modules = importlib.import_module("training.trainer.callbacks") - callbacks = [ - getattr(callback_modules, callback)( - **check_args(experiment_config["callback_args"][callback], train_args) - ) - for callback in experiment_config["callbacks"] - ] + callbacks = [] + for callback in experiment_config["callbacks"]: + args = experiment_config["callback_args"][callback] or {} + callbacks.append(getattr(callback_modules, callback)(**args)) return callbacks @@ -207,11 +192,35 @@ def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> model.load_checkpoint(checkpoint_path) +def evaluate_embedding(model: Type[Model]) -> Dict: + """Evaluates the embedding space.""" + from pytorch_metric_learning import testers + from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator + + accuracy_calculator = AccuracyCalculator( + include=("mean_average_precision_at_r",), k=10 + ) + + def get_all_embeddings(model: Type[Model]) -> Tuple: + tester = testers.BaseTester() + return tester.get_all_embeddings(model.test_dataset, model.network) + + embeddings, labels = get_all_embeddings(model) + logger.info("Computing embedding accuracy") + accuracies = accuracy_calculator.get_accuracy( + embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True + ) + logger.info( + f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" + ) + return accuracies + + def run_experiment( experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False ) -> None: """Runs an experiment.""" - logger.info(f"Experiment config: {json.dumps(experiment_config, indent=2)}") + logger.info(f"Experiment config: {json.dumps(experiment_config)}") # Create new experiment. experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config) @@ -272,7 +281,11 @@ def run_experiment( model.load_from_checkpoint(model_dir / "best.pt") logger.info("Running inference on test set.") - score = trainer.test(model) + if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: + logger.info("Evaluating embedding.") + score = evaluate_embedding(model) + else: + score = trainer.test(model) logger.info(f"Test set evaluation: {score}") |