summaryrefslogtreecommitdiff
path: root/src/training
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 20:10:54 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-07 20:10:54 +0100
commitff9a21d333f11a42e67c1963ed67de9c0fda87c9 (patch)
treeafee959135416fe92cf6df377e84fb0a9e9714a0 /src/training
parent25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (diff)
Minor updates.
Diffstat (limited to 'src/training')
-rw-r--r--src/training/prepare_experiments.py2
-rw-r--r--src/training/run_experiment.py2
-rw-r--r--src/training/trainer/callbacks/wandb_callbacks.py4
3 files changed, 5 insertions, 3 deletions
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
index 6e20bcd..21997af 100644
--- a/src/training/prepare_experiments.py
+++ b/src/training/prepare_experiments.py
@@ -14,7 +14,7 @@ def run_experiments(experiments_filename: str) -> None:
for index in range(num_experiments):
experiment_config = experiments_config["experiments"][index]
experiment_config["experiment_group"] = experiments_config["experiment_group"]
- cmd = f"python training/run_experiment.py --gpu=-1 --save '{json.dumps(experiment_config)}'"
+ cmd = f"poetry run run-experiment --gpu=-1 --save '{json.dumps(experiment_config)}'"
print(cmd)
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 0167725..2c9a196 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -256,7 +256,7 @@ def run_experiment(
# Load from checkpoint if resuming an experiment.
resume = False
if checkpoint is not None or pretrained_weights is not None:
- resume = True
+ # resume = True
_load_from_checkpoint(model, model_dir, pretrained_weights)
logger.info(f"The class mapping is {model.mapping}")
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index df1fd8f..20414df 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -110,7 +110,9 @@ class WandbImageLogger(Callback):
if isinstance(self.targets[i], list):
ground_truth = "".join(
[
- self.model.mapper(int(target_index))
+ self.model.mapper(int(target_index) - 26)
+ if target_index > 35
+ else self.model.mapper(int(target_index))
for target_index in self.targets[i]
]
).rstrip("_")