author     aktersnurra <gustaf.rydholm@gmail.com>    2020-08-09 23:24:02 +0200
committer  aktersnurra <gustaf.rydholm@gmail.com>    2020-08-09 23:24:02 +0200
commit     53677be4ec14854ea4881b0d78730e0414c8dedd (patch)
tree       56eaace5e9906c7d408b6a251ca100b5c8b4e991 /src/training
parent     125d5da5fb845d03bda91426e172bca7f537584a (diff)
Working bash scripts etc.
Diffstat (limited to 'src/training')
-rw-r--r--  src/training/callbacks/wandb_callbacks.py                             4
-rw-r--r--  src/training/experiments/sample_experiment.yml                        96
-rw-r--r--  src/training/population_based_training/__init__.py                     1
-rw-r--r--  src/training/population_based_training/population_based_training.py    1
-rw-r--r--  src/training/prepare_experiments.py                                    13
-rw-r--r--  src/training/run_experiment.py                                         22
-rw-r--r--  src/training/train.py                                                   4
7 files changed, 105 insertions(+), 36 deletions(-)
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/callbacks/wandb_callbacks.py
index f64cbe1..6ada6df 100644
--- a/src/training/callbacks/wandb_callbacks.py
+++ b/src/training/callbacks/wandb_callbacks.py
@@ -72,7 +72,7 @@ class WandbImageLogger(Callback):
     def set_model(self, model: Type[Model]) -> None:
         """Sets the model and extracts validation images from the dataset."""
         self.model = model
-        data_loader = self.model.data_loaders("val")
+        data_loader = self.model.data_loaders["val"]
         if self.example_indices is None:
             self.example_indices = np.random.randint(
                 0, len(data_loader.dataset.data), self.num_examples
@@ -86,7 +86,7 @@
         for i, image in enumerate(self.val_images):
             image = self.transforms(image)
             pred, conf = self.model.predict_on_image(image)
-            ground_truth = self.model._mapping[self.val_targets[i]]
+            ground_truth = self.model.mapper(int(self.val_targets[i]))
             caption = f"Prediction: {pred} Confidence: {conf:.3f} Ground Truth: {ground_truth}"
             images.append(wandb.Image(image, caption=caption))
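
Both hunks track interface changes on the Model class that are not themselves part of this diff: data_loaders is now a dict keyed by split name rather than a method, and the private _mapping attribute is replaced by a public mapper callable. A minimal sketch of the interface the callback now assumes; only data_loaders and mapper are confirmed by the call sites above, everything else is illustrative:

# Sketch of the Model interface assumed by WandbImageLogger after this change.
from typing import Callable, Dict

from torch.utils.data import DataLoader


class Model:
    def __init__(self, data_loaders: Dict[str, DataLoader], mapper: Callable[[int], str]) -> None:
        # A plain dict keyed by split, hence data_loaders["val"]
        # instead of the old data_loaders("val") call.
        self.data_loaders = data_loaders
        # Maps an integer class label to its character, replacing
        # direct access to the private _mapping attribute.
        self.mapper = mapper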
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 70edb63..57198f1 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -1,28 +1,30 @@
 experiment_group: Sample Experiments
 experiments:
-    - dataloader: EmnistDataLoaders
-      data_loader_args:
-          splits: [train, val]
+    - dataset: EmnistDataset
+      dataset_args:
           sample_to_balance: true
           subsample_fraction: null
           transform: null
           target_transform: null
+          seed: 4711
+      data_loader_args:
+          splits: [train, val]
           batch_size: 256
           shuffle: true
           num_workers: 8
           cuda: true
-          seed: 4711
       model: CharacterModel
      metrics: [accuracy]
-      # network: MLP
-      # network_args:
-      #     input_size: 784
-      #     output_size: 62
-      #     num_layers: 3
-      network: LeNet
+      network: MLP
       network_args:
-          input_size: [28, 28]
+          input_size: 784
           output_size: 62
+          num_layers: 3
+          activation_fn: GELU
+      # network: LeNet
+      # network_args:
+      #     output_size: 62
+      #     activation_fn: GELU
       train_args:
           batch_size: 256
           epochs: 16
@@ -66,5 +68,75 @@ experiments:
               num_examples: 4
           OneCycleLR:
               null
-      verbosity: 1 # 0, 1, 2
+      verbosity: 2 # 0, 1, 2
       resume_experiment: null
+    # - dataset: EmnistDataset
+    #   dataset_args:
+    #       sample_to_balance: true
+    #       subsample_fraction: null
+    #       transform: null
+    #       target_transform: null
+    #       seed: 4711
+    #   data_loader_args:
+    #       splits: [train, val]
+    #       batch_size: 256
+    #       shuffle: true
+    #       num_workers: 8
+    #       cuda: true
+    #   model: CharacterModel
+    #   metrics: [accuracy]
+    #   # network: MLP
+    #   # network_args:
+    #   #     input_size: 784
+    #   #     output_size: 62
+    #   #     num_layers: 3
+    #   #     activation_fn: GELU
+    #   network: LeNet
+    #   network_args:
+    #       output_size: 62
+    #       activation_fn: GELU
+    #   train_args:
+    #       batch_size: 256
+    #       epochs: 16
+    #   criterion: CrossEntropyLoss
+    #   criterion_args:
+    #       weight: null
+    #       ignore_index: -100
+    #       reduction: mean
+    #   # optimizer: RMSprop
+    #   # optimizer_args:
+    #   #     lr: 1.e-3
+    #   #     alpha: 0.9
+    #   #     eps: 1.e-7
+    #   #     momentum: 0
+    #   #     weight_decay: 0
+    #   #     centered: false
+    #   optimizer: AdamW
+    #   optimizer_args:
+    #       lr: 1.e-2
+    #       betas: [0.9, 0.999]
+    #       eps: 1.e-08
+    #       weight_decay: 0
+    #       amsgrad: false
+    #   # lr_scheduler: null
+    #   lr_scheduler: OneCycleLR
+    #   lr_scheduler_args:
+    #       max_lr: 1.e-3
+    #       epochs: 16
+    #   callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
+    #   callback_args:
+    #       Checkpoint:
+    #           monitor: val_accuracy
+    #       EarlyStopping:
+    #           monitor: val_loss
+    #           min_delta: 0.0
+    #           patience: 3
+    #           mode: min
+    #       WandbCallback:
+    #           log_batch_frequency: 10
+    #       WandbImageLogger:
+    #           num_examples: 4
+    #       OneCycleLR:
+    #           null
+    #   verbosity: 2 # 0, 1, 2
+    #   resume_experiment: null
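
The new schema separates dataset construction (dataset and dataset_args, which now own the seed) from batching concerns (data_loader_args). A minimal sketch of reading one experiment out of this file, assuming PyYAML is available:

# Minimal sketch of consuming the new config schema, assuming PyYAML.
import yaml

with open("src/training/experiments/sample_experiment.yml") as f:
    config = yaml.safe_load(f)

experiment = config["experiments"][0]
print(experiment["dataset"])           # EmnistDataset
print(experiment["dataset_args"])      # sample_to_balance, seed, ...
print(experiment["data_loader_args"])  # splits, batch_size, num_workers, ...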
diff --git a/src/training/population_based_training/__init__.py b/src/training/population_based_training/__init__.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/training/population_based_training/__init__.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/training/population_based_training/population_based_training.py b/src/training/population_based_training/population_based_training.py
new file mode 100644
index 0000000..868d739
--- /dev/null
+++ b/src/training/population_based_training/population_based_training.py
@@ -0,0 +1 @@
+"""TBC."""
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
index 5a665b3..97c0304 100644
--- a/src/training/prepare_experiments.py
+++ b/src/training/prepare_experiments.py
@@ -16,19 +16,8 @@ def run_experiments(experiments_filename: str) -> None:
     for index in range(num_experiments):
         experiment_config = experiments_config["experiments"][index]
         experiment_config["experiment_group"] = experiments_config["experiment_group"]
-        cmd = f"poetry run run-experiment --gpu=-1 --save --experiment_config={json.dumps(experiment_config)}"
+        cmd = f"python training/run_experiment.py --gpu=-1 --save --experiment_config='{json.dumps(experiment_config)}'"
         print(cmd)
-        run(
-            [
-                "poetry",
-                "run",
-                "run-experiment",
-                "--gpu=-1",
-                "--save",
-                f"--experiment_config={json.dumps(experiment_config)}",
-            ],
-            check=True,
-        )
 
 
 @click.command()
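
run_experiments no longer launches each experiment via subprocess; it only prints one shell command per experiment, which is what makes the bash scripts from the commit message work: the output can be piped to a shell or saved to a script. A self-contained sketch of the revised function, with the imports the fragment implies (yaml for the experiments file, json for the embedded config); the loading step before the loop is assumed from context, and the click plumbing is unchanged:

import json

import yaml


def run_experiments(experiments_filename: str) -> None:
    """Print one run_experiment.py command per experiment, for use in bash scripts."""
    # Loading the experiments file is not shown in the hunk above; assumed here.
    with open(experiments_filename) as f:
        experiments_config = yaml.safe_load(f)
    num_experiments = len(experiments_config["experiments"])
    for index in range(num_experiments):
        experiment_config = experiments_config["experiments"][index]
        experiment_config["experiment_group"] = experiments_config["experiment_group"]
        cmd = f"python training/run_experiment.py --gpu=-1 --save --experiment_config='{json.dumps(experiment_config)}'"
        print(cmd)

The printed commands can then be executed by a shell, e.g. by piping the script's output to bash (the exact CLI flag is defined by the click command above and is not shown in this hunk).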
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index c133ce5..d278dc2 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -58,12 +58,17 @@ def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path:
     return experiment_dir
 
 
+def check_args(args: Dict) -> Dict:
+    """Checks that the arguments are not None."""
+    return args or {}
+
+
 def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]:
     """Loads all modules and arguments."""
-    # Import the data loader module and arguments.
-    datasets_module = importlib.import_module("text_recognizer.datasets")
-    data_loader_ = getattr(datasets_module, experiment_config["dataloader"])
+    # Import the data loader arguments.
     data_loader_args = experiment_config.get("data_loader_args", {})
+    data_loader_args["dataset"] = experiment_config["dataset"]
+    data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
 
     # Import the model module and model arguments.
     models_module = importlib.import_module("text_recognizer.models")
@@ -90,10 +95,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
 
     # Callbacks
     callback_modules = importlib.import_module("training.callbacks")
-    callbacks = []
-    for callback in experiment_config["callbacks"]:
-        args = experiment_config["callback_args"][callback] or {}
-        callbacks.append(getattr(callback_modules, callback)(**args))
+    callbacks = [
+        getattr(callback_modules, callback)(
+            **check_args(experiment_config["callback_args"][callback])
+        )
+        for callback in experiment_config["callbacks"]
+    ]
 
     # Learning rate scheduler
     if experiment_config["lr_scheduler"] is not None:
@@ -106,7 +113,6 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
         lr_scheduler_args = None
 
     model_args = {
-        "data_loader": data_loader_,
        "data_loader_args": data_loader_args,
         "metrics": metric_fns_,
         "network_fn": network_fn_,
diff --git a/src/training/train.py b/src/training/train.py
index 3334c2e..aaa0430 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -106,7 +106,7 @@ class Trainer:
 
         # Running average for the loss.
         loss_avg = RunningAverage()
-        data_loader = self.model.data_loaders("train")
+        data_loader = self.model.data_loaders["train"]
 
         with tqdm(
             total=len(data_loader),
@@ -164,7 +164,7 @@
         self.model.eval()
 
         # Running average for the loss.
-        data_loader = self.model.data_loaders("val")
+        data_loader = self.model.data_loaders["val"]
 
         # Running average for the loss.
         loss_avg = RunningAverage()
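
Both Trainer hunks mirror the data_loaders dict change already seen in wandb_callbacks.py. RunningAverage itself is untouched by this commit; for context, a minimal sketch of what such a helper usually looks like, illustrative rather than the repository's implementation:

# Illustrative sketch of a RunningAverage helper; not from this commit.
class RunningAverage:
    """Keeps a running mean of the values it is updated with."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float) -> None:
        self.total += value
        self.count += 1

    def __call__(self) -> float:
        return self.total / self.count if self.count else 0.0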