diff options
Diffstat (limited to 'src/training')
-rw-r--r-- | src/training/gpu_manager.py | 62 | ||||
-rw-r--r-- | src/training/prepare_experiments.py | 35 | ||||
-rw-r--r-- | src/training/run_experiment.py | 74 | ||||
-rw-r--r-- | src/training/train.py | 17 |
4 files changed, 181 insertions, 7 deletions
diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py new file mode 100644 index 0000000..ce1b3dd --- /dev/null +++ b/src/training/gpu_manager.py @@ -0,0 +1,62 @@ +"""GPUManager class.""" +import os +import time +from typing import Optional + +import gpustat +from loguru import logger +import numpy as np +from redlock import Redlock + + +GPU_LOCK_TIMEOUT = 5000 # ms + + +class GPUManager: + """Class for allocating GPUs.""" + + def __init__(self, verbose: bool = False) -> None: + """Initializes Redlock manager.""" + self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) + self.verbose = verbose + + def get_free_gpu(self) -> int: + """Gets a free GPU. + + If some GPUs are available, try reserving one by checking out an exclusive redis lock. + If none available or can not get lock, sleep and check again. + + Returns: + int: The gpu index. + + """ + while True: + gpu_index = self._get_free_gpu() + if gpu_index is not None: + return gpu_index + + if self.verbose: + logger.debug(f"pid {os.getpid()} sleeping") + time.sleep(GPU_LOCK_TIMEOUT / 1000) + + def _get_free_gpu(self) -> Optional[int]: + """Fetches an available GPU index.""" + try: + available_gpu_indices = [ + gpu.index + for gpu in gpustat.GPUStatCollection.new_query() + if gpu.memory_used < 0.5 * gpu.memory_total + ] + except Exception as e: + logger.debug(f"Got the following exception: {e}") + return None + + if available_gpu_indices: + gpu_index = np.random.choice(available_gpu_indices) + if self.verbose: + logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") + if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): + return int(gpu_index) + if self.verbose: + logger.debug(f"pid {os.getpid()} could not get lock.") + return None diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py new file mode 100644 index 0000000..1ab8f00 --- /dev/null +++ b/src/training/prepare_experiments.py @@ -0,0 +1,35 @@ +"""Run a experiment from a config file.""" +import json + +import click +from loguru import logger +import yaml + + +def run_experiment(experiment_filename: str) -> None: + """Run experiment from file.""" + with open(experiment_filename) as f: + experiments_config = yaml.safe_load(f) + num_experiments = len(experiments_config["experiments"]) + for index in range(num_experiments): + experiment_config = experiments_config["experiments"][index] + experiment_config["experiment_group"] = experiments_config["experiment_group"] + print( + f"python training/run_experiment.py --gpu=-1 '{json.dumps(experiment_config)}'" + ) + + +@click.command() +@click.option( + "--experiments_filename", + required=True, + type=str, + help="Filename of Yaml file of experiments to run.", +) +def main(experiment_filename: str) -> None: + """Parse command-line arguments and run experiments from provided file.""" + run_experiment(experiment_filename) + + +if __name__ == "__main__": + main() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 8033f47..8296e59 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -1 +1,75 @@ """Script to run experiments.""" +import importlib +import os +from typing import Dict + +import click +import torch +from training.train import Trainer + + +def run_experiment( + experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False +) -> None: + """Short summary.""" + # Import the data loader module and arguments. + datasets_module = importlib.import_module("text_recognizer.datasets") + data_loader_ = getattr(datasets_module, experiment_config["dataloader"]) + data_loader_args = experiment_config.get("data_loader_args", {}) + + # Import the model module and model arguments. + models_module = importlib.import_module("text_recognizer.models") + model_class_ = getattr(models_module, experiment_config["model"]) + + # Import metric. + metric_fn_ = getattr(models_module, experiment_config["metric"]) + + # Import network module and arguments. + network_module = importlib.import_module("text_recognizer.networks") + network_fn_ = getattr(network_module, experiment_config["network"]) + network_args = experiment_config.get("network_args", {}) + + # Criterion + criterion_ = getattr(torch.nn, experiment_config["criterion"]) + criterion_args = experiment_config.get("criterion_args", {}) + + # Optimizer + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]) + optimizer_args = experiment_config.get("optimizer_args", {}) + + # Learning rate scheduler + lr_scheduler_ = None + lr_scheduler_args = None + if experiment_config["lr_scheduler"] is not None: + lr_scheduler_ = getattr( + torch.optim.lr_scheduler, experiment_config["lr_scheduler"] + ) + lr_scheduler_args = experiment_config.get("lr_scheduler_args", {}) + + # Device + # TODO fix gpu manager + device = None + + model = model_class_( + network_fn=network_fn_, + network_args=network_args, + data_loader=data_loader_, + data_loader_args=data_loader_args, + metrics=metric_fn_, + criterion=criterion_, + criterion_args=criterion_args, + optimizer=optimizer_, + optimizer_args=optimizer_args, + lr_scheduler=lr_scheduler_, + lr_scheduler_args=lr_scheduler_args, + device=device, + ) + + # TODO: Fix checkpoint path and wandb + trainer = Trainer( + model=model, + epochs=experiment_config["epochs"], + val_metric=experiment_config["metric"], + ) + + trainer.fit() diff --git a/src/training/train.py b/src/training/train.py index 783de02..4a452b6 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -9,6 +9,8 @@ import numpy as np import torch from tqdm import tqdm, trange from training.util import RunningAverage +import wandb + torch.backends.cudnn.benchmark = True np.random.seed(4711) @@ -30,6 +32,7 @@ class Trainer: epochs: int, val_metric: str = "accuracy", checkpoint_path: Optional[Path] = None, + use_wandb: Optional[bool] = False, ) -> None: """Initialization of the Trainer. @@ -38,6 +41,7 @@ class Trainer: epochs (int): Number of epochs to train. val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy". checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + use_wandb (Optional[bool]): Sync training to wandb. """ self.model = model @@ -48,13 +52,16 @@ class Trainer: if self.checkpoint_path is not None: self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + if use_wandb: + # TODO implement wandb logging. + pass + self.val_metric = val_metric self.best_val_metric = 0.0 logger.add(self.model.name + "_{time}.log") def train(self) -> None: """Training loop.""" - # TODO add summary # Set model to traning mode. self.model.train() @@ -93,10 +100,6 @@ class Trainer: # Perform updates using calculated gradients. self.model.optimizer.step() - # Update the learning rate scheduler. - if self.model.lr_scheduler is not None: - self.model.lr_scheduler.step() - # Compute metrics. loss_avg.update(loss.item()) output = output.data.cpu() @@ -174,8 +177,8 @@ class Trainer: return metrics_mean - def run(self) -> None: - """Training and evaluation loop.""" + def fit(self) -> None: + """Runs the training and evaluation loop.""" # Create new experiment. EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment = datetime.now().strftime("%m%d_%H%M%S") |