diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 |
commit | 5a78fc2e33c28968a69d033cb10d638f4f63fed1 (patch) | |
tree | e11cea8366c848e5500f85968ee5369ff8d96b00 /src/training/gpu_manager.py | |
parent | 7c4de6d88664d2ea1b084f316a11896dde3e1150 (diff) |
Working on getting experiment loop.
Diffstat (limited to 'src/training/gpu_manager.py')
-rw-r--r-- | src/training/gpu_manager.py | 62 |
1 files changed, 62 insertions, 0 deletions
diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py new file mode 100644 index 0000000..ce1b3dd --- /dev/null +++ b/src/training/gpu_manager.py @@ -0,0 +1,62 @@ +"""GPUManager class.""" +import os +import time +from typing import Optional + +import gpustat +from loguru import logger +import numpy as np +from redlock import Redlock + + +GPU_LOCK_TIMEOUT = 5000 # ms + + +class GPUManager: + """Class for allocating GPUs.""" + + def __init__(self, verbose: bool = False) -> None: + """Initializes Redlock manager.""" + self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) + self.verbose = verbose + + def get_free_gpu(self) -> int: + """Gets a free GPU. + + If some GPUs are available, try reserving one by checking out an exclusive redis lock. + If none available or can not get lock, sleep and check again. + + Returns: + int: The gpu index. + + """ + while True: + gpu_index = self._get_free_gpu() + if gpu_index is not None: + return gpu_index + + if self.verbose: + logger.debug(f"pid {os.getpid()} sleeping") + time.sleep(GPU_LOCK_TIMEOUT / 1000) + + def _get_free_gpu(self) -> Optional[int]: + """Fetches an available GPU index.""" + try: + available_gpu_indices = [ + gpu.index + for gpu in gpustat.GPUStatCollection.new_query() + if gpu.memory_used < 0.5 * gpu.memory_total + ] + except Exception as e: + logger.debug(f"Got the following exception: {e}") + return None + + if available_gpu_indices: + gpu_index = np.random.choice(available_gpu_indices) + if self.verbose: + logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") + if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): + return int(gpu_index) + if self.verbose: + logger.debug(f"pid {os.getpid()} could not get lock.") + return None |