diff options
| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 | 
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-07-05 22:27:08 +0200 | 
| commit | 5a78fc2e33c28968a69d033cb10d638f4f63fed1 (patch) | |
| tree | e11cea8366c848e5500f85968ee5369ff8d96b00 /src/training/gpu_manager.py | |
| parent | 7c4de6d88664d2ea1b084f316a11896dde3e1150 (diff) | |
Working on getting experiment loop.
Diffstat (limited to 'src/training/gpu_manager.py')
| -rw-r--r-- | src/training/gpu_manager.py | 62 | 
1 files changed, 62 insertions, 0 deletions
diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py new file mode 100644 index 0000000..ce1b3dd --- /dev/null +++ b/src/training/gpu_manager.py @@ -0,0 +1,62 @@ +"""GPUManager class.""" +import os +import time +from typing import Optional + +import gpustat +from loguru import logger +import numpy as np +from redlock import Redlock + + +GPU_LOCK_TIMEOUT = 5000  # ms + + +class GPUManager: +    """Class for allocating GPUs.""" + +    def __init__(self, verbose: bool = False) -> None: +        """Initializes Redlock manager.""" +        self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) +        self.verbose = verbose + +    def get_free_gpu(self) -> int: +        """Gets a free GPU. + +        If some GPUs are available, try reserving one by checking out an exclusive redis lock. +        If none available or can not get lock, sleep and check again. + +        Returns: +            int: The gpu index. + +        """ +        while True: +            gpu_index = self._get_free_gpu() +            if gpu_index is not None: +                return gpu_index + +            if self.verbose: +                logger.debug(f"pid {os.getpid()} sleeping") +            time.sleep(GPU_LOCK_TIMEOUT / 1000) + +    def _get_free_gpu(self) -> Optional[int]: +        """Fetches an available GPU index.""" +        try: +            available_gpu_indices = [ +                gpu.index +                for gpu in gpustat.GPUStatCollection.new_query() +                if gpu.memory_used < 0.5 * gpu.memory_total +            ] +        except Exception as e: +            logger.debug(f"Got the following exception: {e}") +            return None + +        if available_gpu_indices: +            gpu_index = np.random.choice(available_gpu_indices) +            if self.verbose: +                logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") +            if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): +                return int(gpu_index) +            if self.verbose: +                logger.debug(f"pid {os.getpid()} could not get lock.") +        return None  |