summaryrefslogtreecommitdiff
path: root/src/training/gpu_manager.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/gpu_manager.py')
-rw-r--r--src/training/gpu_manager.py62
1 files changed, 62 insertions, 0 deletions
diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py
new file mode 100644
index 0000000..ce1b3dd
--- /dev/null
+++ b/src/training/gpu_manager.py
@@ -0,0 +1,62 @@
+"""GPUManager class."""
+import os
+import time
+from typing import Optional
+
+import gpustat
+from loguru import logger
+import numpy as np
+from redlock import Redlock
+
+
+GPU_LOCK_TIMEOUT = 5000 # ms
+
+
+class GPUManager:
+ """Class for allocating GPUs."""
+
+ def __init__(self, verbose: bool = False) -> None:
+ """Initializes Redlock manager."""
+ self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}])
+ self.verbose = verbose
+
+ def get_free_gpu(self) -> int:
+ """Gets a free GPU.
+
+ If some GPUs are available, try reserving one by checking out an exclusive redis lock.
+ If none available or can not get lock, sleep and check again.
+
+ Returns:
+ int: The gpu index.
+
+ """
+ while True:
+ gpu_index = self._get_free_gpu()
+ if gpu_index is not None:
+ return gpu_index
+
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} sleeping")
+ time.sleep(GPU_LOCK_TIMEOUT / 1000)
+
+ def _get_free_gpu(self) -> Optional[int]:
+ """Fetches an available GPU index."""
+ try:
+ available_gpu_indices = [
+ gpu.index
+ for gpu in gpustat.GPUStatCollection.new_query()
+ if gpu.memory_used < 0.5 * gpu.memory_total
+ ]
+ except Exception as e:
+ logger.debug(f"Got the following exception: {e}")
+ return None
+
+ if available_gpu_indices:
+ gpu_index = np.random.choice(available_gpu_indices)
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}")
+ if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT):
+ return int(gpu_index)
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} could not get lock.")
+ return None