Diffstat (limited to 'src/training')
-rw-r--r--   src/training/gpu_manager.py           62
-rw-r--r--   src/training/prepare_experiments.py   35
-rw-r--r--   src/training/run_experiment.py        74
-rw-r--r--   src/training/train.py                 17
4 files changed, 181 insertions(+), 7 deletions(-)
diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py
new file mode 100644
index 0000000..ce1b3dd
--- /dev/null
+++ b/src/training/gpu_manager.py
@@ -0,0 +1,62 @@
+"""GPUManager class."""
+import os
+import time
+from typing import Optional
+
+import gpustat
+from loguru import logger
+import numpy as np
+from redlock import Redlock
+
+
+GPU_LOCK_TIMEOUT = 5000 # ms
+
+
+class GPUManager:
+ """Class for allocating GPUs."""
+
+ def __init__(self, verbose: bool = False) -> None:
+ """Initializes Redlock manager."""
+ self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}])
+ self.verbose = verbose
+
+ def get_free_gpu(self) -> int:
+ """Gets a free GPU.
+
+        If some GPUs are available, try to reserve one by acquiring an exclusive Redis lock.
+        If none are available, or the lock cannot be acquired, sleep and check again.
+
+ Returns:
+            int: The index of the allocated GPU.
+
+ """
+ while True:
+ gpu_index = self._get_free_gpu()
+ if gpu_index is not None:
+ return gpu_index
+
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} sleeping")
+ time.sleep(GPU_LOCK_TIMEOUT / 1000)
+
+ def _get_free_gpu(self) -> Optional[int]:
+ """Fetches an available GPU index."""
+ try:
+ available_gpu_indices = [
+ gpu.index
+ for gpu in gpustat.GPUStatCollection.new_query()
+ if gpu.memory_used < 0.5 * gpu.memory_total
+ ]
+ except Exception as e:
+ logger.debug(f"Got the following exception: {e}")
+ return None
+
+ if available_gpu_indices:
+ gpu_index = np.random.choice(available_gpu_indices)
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}")
+ if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT):
+ return int(gpu_index)
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} could not get lock.")
+ return None
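A minimal usage sketch (not part of the diff) of how a training worker could reserve a GPU with this class. It assumes a Redis server on localhost:6379, as configured in GPUManager above, and that the caller pins the device via CUDA_VISIBLE_DEVICES; the actual wiring into run_experiment.py is still marked TODO further down.

    import os

    from training.gpu_manager import GPUManager

    gpu_manager = GPUManager(verbose=True)
    gpu_index = gpu_manager.get_free_gpu()  # blocks until a GPU lock is acquired

    # Restrict this process to the reserved GPU so the framework only sees that device.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_index)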
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
new file mode 100644
index 0000000..1ab8f00
--- /dev/null
+++ b/src/training/prepare_experiments.py
@@ -0,0 +1,35 @@
+"""Run a experiment from a config file."""
+import json
+
+import click
+from loguru import logger
+import yaml
+
+
+def run_experiment(experiment_filename: str) -> None:
+    """Print the run command for each experiment in the experiments file."""
+ with open(experiment_filename) as f:
+ experiments_config = yaml.safe_load(f)
+ num_experiments = len(experiments_config["experiments"])
+ for index in range(num_experiments):
+ experiment_config = experiments_config["experiments"][index]
+ experiment_config["experiment_group"] = experiments_config["experiment_group"]
+ print(
+ f"python training/run_experiment.py --gpu=-1 '{json.dumps(experiment_config)}'"
+ )
+
+
+@click.command()
+@click.option(
+    "--experiments_filename",
+    required=True,
+    type=str,
+    help="Filename of the YAML file of experiments to run.",
+)
+def main(experiments_filename: str) -> None:
+    """Parse command-line arguments and run experiments from the provided file."""
+    run_experiment(experiments_filename)
+
+
+if __name__ == "__main__":
+ main()
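The script only relies on two top-level keys in the experiments file: experiment_group and the list under experiments; everything inside each experiment dict is forwarded verbatim to run_experiment.py. A hypothetical parsed config, shown as the structure yaml.safe_load returns above (the per-experiment fields are placeholders):

    experiments_config = {
        "experiment_group": "sample_experiments",  # copied into every experiment config
        "experiments": [
            {"model": "CharacterModel", "epochs": 16},  # placeholder fields
            {"model": "CharacterModel", "epochs": 32},
        ],
    }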
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 8033f47..8296e59 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -1 +1,75 @@
"""Script to run experiments."""
+import importlib
+import os
+from typing import Dict
+
+import click
+import torch
+from training.train import Trainer
+
+
+def run_experiment(
+ experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False
+) -> None:
+    """Runs an experiment defined by the experiment config."""
+ # Import the data loader module and arguments.
+ datasets_module = importlib.import_module("text_recognizer.datasets")
+ data_loader_ = getattr(datasets_module, experiment_config["dataloader"])
+ data_loader_args = experiment_config.get("data_loader_args", {})
+
+ # Import the model module and model arguments.
+ models_module = importlib.import_module("text_recognizer.models")
+ model_class_ = getattr(models_module, experiment_config["model"])
+
+ # Import metric.
+ metric_fn_ = getattr(models_module, experiment_config["metric"])
+
+ # Import network module and arguments.
+ network_module = importlib.import_module("text_recognizer.networks")
+ network_fn_ = getattr(network_module, experiment_config["network"])
+ network_args = experiment_config.get("network_args", {})
+
+ # Criterion
+ criterion_ = getattr(torch.nn, experiment_config["criterion"])
+ criterion_args = experiment_config.get("criterion_args", {})
+
+ # Optimizer
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
+ optimizer_args = experiment_config.get("optimizer_args", {})
+
+ # Learning rate scheduler
+ lr_scheduler_ = None
+ lr_scheduler_args = None
+ if experiment_config["lr_scheduler"] is not None:
+ lr_scheduler_ = getattr(
+ torch.optim.lr_scheduler, experiment_config["lr_scheduler"]
+ )
+ lr_scheduler_args = experiment_config.get("lr_scheduler_args", {})
+
+ # Device
+ # TODO fix gpu manager
+ device = None
+
+ model = model_class_(
+ network_fn=network_fn_,
+ network_args=network_args,
+ data_loader=data_loader_,
+ data_loader_args=data_loader_args,
+ metrics=metric_fn_,
+ criterion=criterion_,
+ criterion_args=criterion_args,
+ optimizer=optimizer_,
+ optimizer_args=optimizer_args,
+ lr_scheduler=lr_scheduler_,
+ lr_scheduler_args=lr_scheduler_args,
+ device=device,
+ )
+
+ # TODO: Fix checkpoint path and wandb
+ trainer = Trainer(
+ model=model,
+ epochs=experiment_config["epochs"],
+ val_metric=experiment_config["metric"],
+ )
+
+ trainer.fit()
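The keys accessed above imply an experiment config roughly like the sketch below. The criterion, optimizer, and lr_scheduler values are standard PyTorch class names; the dataloader, model, metric, and network values are placeholders for whatever the text_recognizer packages actually expose.

    experiment_config = {
        "dataloader": "EmnistDataLoader",     # placeholder; resolved from text_recognizer.datasets
        "data_loader_args": {"batch_size": 128},
        "model": "CharacterModel",            # placeholder; resolved from text_recognizer.models
        "metric": "accuracy",                 # placeholder; resolved from text_recognizer.models
        "network": "MLP",                     # placeholder; resolved from text_recognizer.networks
        "network_args": {"num_layers": 3},
        "criterion": "CrossEntropyLoss",      # resolved from torch.nn
        "criterion_args": {},
        "optimizer": "Adam",                  # resolved from torch.optim
        "optimizer_args": {"lr": 3e-4},
        "lr_scheduler": "StepLR",             # resolved from torch.optim.lr_scheduler; key must exist, may be None
        "lr_scheduler_args": {"step_size": 5},
        "epochs": 16,
    }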
diff --git a/src/training/train.py b/src/training/train.py
index 783de02..4a452b6 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -9,6 +9,8 @@ import numpy as np
import torch
from tqdm import tqdm, trange
from training.util import RunningAverage
+import wandb
+
torch.backends.cudnn.benchmark = True
np.random.seed(4711)
@@ -30,6 +32,7 @@ class Trainer:
        epochs: int,
        val_metric: str = "accuracy",
        checkpoint_path: Optional[Path] = None,
+ use_wandb: Optional[bool] = False,
    ) -> None:
        """Initialization of the Trainer.
@@ -38,6 +41,7 @@ class Trainer:
            epochs (int): Number of epochs to train.
            val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy".
            checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
+ use_wandb (Optional[bool]): Sync training to wandb.
"""
self.model = model
@@ -48,13 +52,16 @@ class Trainer:
        if self.checkpoint_path is not None:
            self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
+ if use_wandb:
+ # TODO implement wandb logging.
+ pass
+
        self.val_metric = val_metric
        self.best_val_metric = 0.0
        logger.add(self.model.name + "_{time}.log")

    def train(self) -> None:
        """Training loop."""
- # TODO add summary
        # Set model to traning mode.
        self.model.train()
@@ -93,10 +100,6 @@ class Trainer:
        # Perform updates using calculated gradients.
        self.model.optimizer.step()
- # Update the learning rate scheduler.
- if self.model.lr_scheduler is not None:
- self.model.lr_scheduler.step()
-
        # Compute metrics.
        loss_avg.update(loss.item())
        output = output.data.cpu()
@@ -174,8 +177,8 @@ class Trainer:
        return metrics_mean

- def run(self) -> None:
- """Training and evaluation loop."""
+ def fit(self) -> None:
+ """Runs the training and evaluation loop."""
        # Create new experiment.
        EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
        experiment = datetime.now().strftime("%m%d_%H%M%S")
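The wandb hook added to __init__ above is still a TODO. One possible shape for it, as a hedged sketch using the standard wandb API (the project name is an assumption, not part of this diff):

    if use_wandb:
        wandb.init(
            project="text-recognizer",  # assumed project name
            config={"epochs": epochs, "val_metric": val_metric},
        )
        # Per-step metrics could then be synced from the train/validation loops,
        # e.g. wandb.log({"train_loss": loss.item()}).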