author     aktersnurra <gustaf.rydholm@gmail.com>    2020-11-08 12:41:04 +0100
committer  aktersnurra <gustaf.rydholm@gmail.com>    2020-11-08 12:41:04 +0100
commit     beeaef529e7c893a3475fe27edc880e283373725
tree       59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/training/run_experiment.py
parent     4d7713746eb936832e84852e90292936b933e87d
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--  src/training/run_experiment.py  28
1 file changed, 17 insertions, 11 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index c0f969d..0510d5c 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -73,7 +73,7 @@ def _create_experiment_dir(
     return experiment_dir, log_dir, model_dir
 
 
-def _load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
+def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]:
     """Loads all modules and arguments."""
     # Load the dataset module.
     dataset_args = experiment_config.get("dataset", {})
@@ -104,7 +104,7 @@ def _load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict
         criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"])
     else:
         criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
-    criterion_args = experiment_config["criterion"].get("args", {})
+    criterion_args = experiment_config["criterion"].get("args", {}) or {}
 
     # Optimizers
     if experiment_config["optimizer"]["type"] == "AdaBelief":
@@ -187,18 +187,20 @@ def _save_config(experiment_dir: Path, experiment_config: Dict) -> None:
 
 
 def _load_from_checkpoint(
-    model: Type[Model], log_dir: Path, model_dir: Path, pretrained_weights: str = None
+    model: Type[Model], model_dir: Path, pretrained_weights: str = None,
 ) -> None:
     """If checkpoint exists, load model weights and optimizers from checkpoint."""
     # Get checkpoint path.
     if pretrained_weights is not None:
         logger.info(f"Loading weights from {pretrained_weights}.")
-        checkpoint_path = Path(pretrained_weights) / "model" / "best.pt"
+        checkpoint_path = (
+            EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt"
+        )
     else:
         logger.info(f"Loading weights from {model_dir}.")
         checkpoint_path = model_dir / "last.pt"
 
     if checkpoint_path.exists():
-        logger.info("Loading and resuming training from last checkpoint.")
+        logger.info("Loading and resuming training from checkpoint.")
         model.load_from_checkpoint(checkpoint_path)
@@ -230,9 +232,9 @@ def run_experiment(
     experiment_config: Dict,
     save_weights: bool,
     device: str,
-    use_wandb: bool = False,
-    train: bool = True,
-    test: bool = False,
+    use_wandb: bool,
+    train: bool,
+    test: bool,
     verbose: int = 0,
     checkpoint: Optional[str] = None,
     pretrained_weights: Optional[str] = None,
@@ -264,7 +266,7 @@ def run_experiment(
     resume = False
     if checkpoint is not None or pretrained_weights is not None:
         resume = True
-        _load_from_checkpoint(model, log_dir, model_dir, pretrained_weights)
+        _load_from_checkpoint(model, model_dir, pretrained_weights)
 
     logger.info(f"The class mapping is {model.mapping}")
@@ -297,6 +299,7 @@ def run_experiment(
         max_epochs=experiment_config["train_args"]["max_epochs"],
         callbacks=callbacks,
         transformer_model=experiment_config["train_args"]["transformer_model"],
+        max_norm=experiment_config["train_args"]["max_norm"],
     )
 
     # Train the model.
@@ -309,7 +312,7 @@ def run_experiment(
         model.load_from_checkpoint(model_dir / "best.pt")
 
     logger.info("Running inference on test set.")
-    if experiment_config["criterion"]["type"] in custom_loss_module.__all__:
+    if experiment_config["criterion"]["type"] == "EmbeddingLoss":
         logger.info("Evaluating embedding.")
         score = evaluate_embedding(model)
     else:
@@ -341,13 +344,15 @@ def run_experiment(
 @click.option(
     "--nowandb", is_flag=False, help="If true, do not use wandb for this run."
 )
-@click.option("--notrain", is_flag=False, help="Do not train the model.")
 @click.option("--test", is_flag=True, help="If true, test the model.")
 @click.option("-v", "--verbose", count=True)
 @click.option("--checkpoint", type=str, help="Path to the experiment.")
 @click.option(
     "--pretrained_weights", type=str, help="Path to pretrained model weights."
 )
+@click.option(
+    "--notrain", is_flag=False, is_eager=True, help="Do not train the model.",
+)
 def run_cli(
     experiment_config: str,
     gpu: int,
@@ -367,6 +372,7 @@ def run_cli(
     experiment_config = json.loads(experiment_config)
     os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
+
     run_experiment(
         experiment_config,
         save,
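
Two of the changes above benefit from a short illustration. First, the switch to `experiment_config["criterion"].get("args", {}) or {}`: a minimal sketch of the failure mode this guards against, assuming the experiment config is parsed from YAML/JSON where the `args` key can be present but null. The sample config literal below is made up; only the access pattern mirrors the diff.

```python
# Why ".get('args', {}) or {}" differs from ".get('args', {})":
# if the key exists with a null value, .get() returns None, not the default.
import torch.nn as nn

criterion_config = {"type": "CrossEntropyLoss", "args": None}  # e.g. "args:" left empty in YAML

plain = criterion_config.get("args", {})          # -> None: key exists, default is not used
guarded = criterion_config.get("args", {}) or {}  # -> {}: None is falsy, so it is replaced

# Unpacking None with ** would raise TypeError; the guarded dict works.
criterion = getattr(nn, criterion_config["type"])(**guarded)
print(type(criterion).__name__)  # CrossEntropyLoss
```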
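Second, the new `max_norm` value forwarded to the `Trainer`: the diff does not show how it is consumed, but the conventional use is gradient-norm clipping via PyTorch's `torch.nn.utils.clip_grad_norm_`. The training-step function below is a hypothetical sketch of that pattern, not this repository's `Trainer` implementation.

```python
# Hypothetical training step illustrating how a max_norm setting is typically
# applied: clip the global gradient norm before the optimizer update.
import torch


def training_step(model, batch, criterion, optimizer, max_norm):
    inputs, targets = batch
    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    # Rescale gradients so their combined L2 norm does not exceed max_norm.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    optimizer.step()
    return loss.item()
```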