summaryrefslogtreecommitdiff
path: root/src/training/run_experiment.py
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-08 12:41:04 +0100
commitbeeaef529e7c893a3475fe27edc880e283373725 (patch)
tree59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/training/run_experiment.py
parent4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--src/training/run_experiment.py28
1 files changed, 17 insertions, 11 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index c0f969d..0510d5c 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -73,7 +73,7 @@ def _create_experiment_dir(
return experiment_dir, log_dir, model_dir
-def _load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
+def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dict]:
"""Loads all modules and arguments."""
# Load the dataset module.
dataset_args = experiment_config.get("dataset", {})
@@ -104,7 +104,7 @@ def _load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict
criterion_ = getattr(custom_loss_module, experiment_config["criterion"]["type"])
else:
criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"])
- criterion_args = experiment_config["criterion"].get("args", {})
+ criterion_args = experiment_config["criterion"].get("args", {}) or {}
# Optimizers
if experiment_config["optimizer"]["type"] == "AdaBelief":
@@ -187,18 +187,20 @@ def _save_config(experiment_dir: Path, experiment_config: Dict) -> None:
def _load_from_checkpoint(
- model: Type[Model], log_dir: Path, model_dir: Path, pretrained_weights: str = None
+ model: Type[Model], model_dir: Path, pretrained_weights: str = None,
) -> None:
"""If checkpoint exists, load model weights and optimizers from checkpoint."""
# Get checkpoint path.
if pretrained_weights is not None:
logger.info(f"Loading weights from {pretrained_weights}.")
- checkpoint_path = Path(pretrained_weights) / "model" / "best.pt"
+ checkpoint_path = (
+ EXPERIMENTS_DIRNAME / Path(pretrained_weights) / "model" / "best.pt"
+ )
else:
logger.info(f"Loading weights from {model_dir}.")
checkpoint_path = model_dir / "last.pt"
if checkpoint_path.exists():
- logger.info("Loading and resuming training from last checkpoint.")
+ logger.info("Loading and resuming training from checkpoint.")
model.load_from_checkpoint(checkpoint_path)
@@ -230,9 +232,9 @@ def run_experiment(
experiment_config: Dict,
save_weights: bool,
device: str,
- use_wandb: bool = False,
- train: bool = True,
- test: bool = False,
+ use_wandb: bool,
+ train: bool,
+ test: bool,
verbose: int = 0,
checkpoint: Optional[str] = None,
pretrained_weights: Optional[str] = None,
@@ -264,7 +266,7 @@ def run_experiment(
resume = False
if checkpoint is not None or pretrained_weights is not None:
resume = True
- _load_from_checkpoint(model, log_dir, model_dir, pretrained_weights)
+ _load_from_checkpoint(model, model_dir, pretrained_weights)
logger.info(f"The class mapping is {model.mapping}")
@@ -297,6 +299,7 @@ def run_experiment(
max_epochs=experiment_config["train_args"]["max_epochs"],
callbacks=callbacks,
transformer_model=experiment_config["train_args"]["transformer_model"],
+ max_norm=experiment_config["train_args"]["max_norm"],
)
# Train the model.
@@ -309,7 +312,7 @@ def run_experiment(
model.load_from_checkpoint(model_dir / "best.pt")
logger.info("Running inference on test set.")
- if experiment_config["criterion"]["type"] in custom_loss_module.__all__:
+ if experiment_config["criterion"]["type"] == "EmbeddingLoss":
logger.info("Evaluating embedding.")
score = evaluate_embedding(model)
else:
@@ -341,13 +344,15 @@ def run_experiment(
@click.option(
"--nowandb", is_flag=False, help="If true, do not use wandb for this run."
)
-@click.option("--notrain", is_flag=False, help="Do not train the model.")
@click.option("--test", is_flag=True, help="If true, test the model.")
@click.option("-v", "--verbose", count=True)
@click.option("--checkpoint", type=str, help="Path to the experiment.")
@click.option(
"--pretrained_weights", type=str, help="Path to pretrained model weights."
)
+@click.option(
+ "--notrain", is_flag=False, is_eager=True, help="Do not train the model.",
+)
def run_cli(
experiment_config: str,
gpu: int,
@@ -367,6 +372,7 @@ def run_cli(
experiment_config = json.loads(experiment_config)
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu}"
+
run_experiment(
experiment_config,
save,