diff options
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r-- | src/training/run_experiment.py | 45 |
1 files changed, 35 insertions, 10 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 0b29ce9..c133ce5 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -12,8 +12,10 @@ import click from loguru import logger import torch from tqdm import tqdm +from training.callbacks import CallbackList from training.gpu_manager import GPUManager from training.train import Trainer +import wandb import yaml @@ -48,9 +50,8 @@ def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: logger.debug(f"Resuming the latest experiment {experiment}") else: experiment = experiment_config["resume_experiment"] - assert ( - str(experiment_dir / experiment) in available_experiments - ), "Experiment does not exist." + if not str(experiment_dir / experiment) in available_experiments: + raise FileNotFoundError("Experiment does not exist.") logger.debug(f"Resuming the experiment {experiment}") experiment_dir = experiment_dir / experiment @@ -87,6 +88,13 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] optimizer_ = getattr(torch.optim, experiment_config["optimizer"]) optimizer_args = experiment_config.get("optimizer_args", {}) + # Callbacks + callback_modules = importlib.import_module("training.callbacks") + callbacks = [] + for callback in experiment_config["callbacks"]: + args = experiment_config["callback_args"][callback] or {} + callbacks.append(getattr(callback_modules, callback)(**args)) + # Learning rate scheduler if experiment_config["lr_scheduler"] is not None: lr_scheduler_ = getattr( @@ -111,7 +119,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] "lr_scheduler_args": lr_scheduler_args, } - return model_class_, model_args + return model_class_, model_args, callbacks def run_experiment( @@ -120,11 +128,14 @@ def run_experiment( """Runs an experiment.""" # Load the modules and model arguments. 
- model_class_, model_args = load_modules_and_arguments(experiment_config) + model_class_, model_args, callbacks = load_modules_and_arguments(experiment_config) # Initializes the model with experiment config. model = model_class_(**model_args, device=device) + # Instantiate a CallbackList. + callbacks = CallbackList(model, callbacks) + # Create new experiment. experiment_dir = create_experiment_dir(model, experiment_config) @@ -132,6 +143,9 @@ def run_experiment( log_dir = experiment_dir / "log" model_dir = experiment_dir / "model" + # Set the model dir to be able to save checkpoints. + model.model_dir = model_dir + # Get checkpoint path. checkpoint_path = model_dir / "last.pt" if not checkpoint_path.exists(): @@ -162,6 +176,13 @@ def run_experiment( logger.info(f"The class mapping is {model.mapping}") + # Initializes Weights & Biases + if use_wandb: + wandb.init(project="text-recognizer", config=experiment_config) + + # Lets W&B save the model and track the gradients and optional parameters. + wandb.watch(model.network) + # Prints a summary of the network in terminal. 
model.summary() @@ -181,21 +202,26 @@ def run_experiment( with open(str(config_path), "w") as f: yaml.dump(experiment_config, f) - # TODO: wandb trainer = Trainer( model=model, model_dir=model_dir, - epochs=experiment_config["train_args"]["epochs"], - val_metric=experiment_config["train_args"]["val_metric"], + train_args=experiment_config["train_args"], + callbacks=callbacks, checkpoint_path=checkpoint_path, ) trainer.fit() + logger.info("Loading checkpoint with the best weights.") + model.load_checkpoint(model_dir / "best.pt") + score = trainer.validate() logger.info(f"Validation set evaluation: {score}") + if use_wandb: + wandb.log({"validation_metric": score["val_accuracy"]}) + if save_weights: model.save_weights(model_dir) @@ -220,12 +246,11 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: if gpu < 0: gpu_manager = GPUManager(True) gpu = gpu_manager.get_free_gpu() - device = "cuda:" + str(gpu) experiment_config = json.loads(experiment_config) os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}" - run_experiment(experiment_config, save, device, nowandb) + run_experiment(experiment_config, save, device, use_wandb=not nowandb) if __name__ == "__main__": |