diff options
-rw-r--r-- | README.md | 9 | ||||
-rw-r--r-- | poetry.lock | 67 | ||||
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | src/notebooks/Untitled.ipynb | 177 | ||||
-rw-r--r-- | src/text_recognizer/datasets/__init__.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/datasets/data_loader.py | 15 | ||||
-rw-r--r-- | src/text_recognizer/datasets/emnist_dataset.py | 258 | ||||
-rw-r--r-- | src/text_recognizer/models/base.py | 12 | ||||
-rw-r--r-- | src/text_recognizer/models/character_model.py | 8 | ||||
-rw-r--r-- | src/text_recognizer/models/metrics.py (renamed from src/text_recognizer/models/util.py) | 2 | ||||
-rw-r--r-- | src/training/gpu_manager.py | 62 | ||||
-rw-r--r-- | src/training/prepare_experiments.py | 35 | ||||
-rw-r--r-- | src/training/run_experiment.py | 74 | ||||
-rw-r--r-- | src/training/train.py | 17 |
14 files changed, 603 insertions, 137 deletions
@@ -3,6 +3,13 @@ Implementing the text recognizer project from the course ["Full Stack Deep Learn ## Setup ---- TBC + +## Todo +-[x] subsampling +-[] Be able to run experiments +-[] Train models +-[] Implement wandb +-[] Implement Bayesian hyperparameter search +-[] New models and datasets diff --git a/poetry.lock b/poetry.lock index 348987b..faf2e7b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -114,6 +114,17 @@ six = ">=1.9.0" webencodings = "*" [[package]] +category = "dev" +description = "A thin, practical wrapper around terminal coloring, styling, and positioning" +name = "blessings" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" +version = "1.7" + +[package.dependencies] +six = "*" + +[[package]] category = "main" description = "When they're not builtins, they're boltons." name = "boltons" @@ -421,6 +432,23 @@ version = "3.1.2" gitdb = ">=4.0.1,<5" [[package]] +category = "dev" +description = "An utility to monitor NVIDIA GPU status and usage" +name = "gpustat" +optional = false +python-versions = "*" +version = "0.6.0" + +[package.dependencies] +blessings = ">=1.6" +nvidia-ml-py3 = ">=7.352.0" +psutil = "*" +six = ">=1.7" + +[package.extras] +test = ["mock (>=2.0.0)", "pytest (<5.0)"] + +[[package]] category = "main" description = "GraphQL client for Python" name = "gql" @@ -1315,6 +1343,28 @@ version = "1.9.0" [[package]] category = "dev" +description = "Python client for Redis key-value store" +name = "redis" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*" +version = "3.5.3" + +[package.extras] +hiredis = ["hiredis (>=0.1.3)"] + +[[package]] +category = "dev" +description = "Redis locking mechanism" +name = "redlock-py" +optional = false +python-versions = "*" +version = "1.0.8" + +[package.dependencies] +redis = "*" + +[[package]] +category = "dev" description = "Alternative regular expression module, to replace re." name = "regex" optional = false @@ -1836,7 +1886,7 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "4b4b531a4a45f81cf30cfdd45f34ef07a980689b5af2c99671e34c9ff9158836" +content-hash = "21a2ef803311c6b5f5bbba56f2baa1b303fd0dd742f46c87f06b1d6dd01765b4" python-versions = "^3.7" [metadata.files] @@ -1880,6 +1930,11 @@ bleach = [ {file = "bleach-3.1.5-py2.py3-none-any.whl", hash = "sha256:2bce3d8fab545a6528c8fa5d9f9ae8ebc85a56da365c7f85180bfe96a35ef22f"}, {file = "bleach-3.1.5.tar.gz", hash = "sha256:3c4c520fdb9db59ef139915a5db79f8b51bc2a7257ea0389f30c846883430a4b"}, ] +blessings = [ + {file = "blessings-1.7-py2-none-any.whl", hash = "sha256:caad5211e7ba5afe04367cdd4cfc68fa886e2e08f6f35e76b7387d2109ccea6e"}, + {file = "blessings-1.7-py3-none-any.whl", hash = "sha256:b1fdd7e7a675295630f9ae71527a8ebc10bfefa236b3d6aa4932ee4462c17ba3"}, + {file = "blessings-1.7.tar.gz", hash = "sha256:98e5854d805f50a5b58ac2333411b0482516a8210f23f43308baeb58d77c157d"}, +] boltons = [ {file = "boltons-20.1.0-py2.py3-none-any.whl", hash = "sha256:b3fc2b711f50cd975e726324d98e0bd5a324dd7e3b81d5e6a1b03c542d0c66c4"}, {file = "boltons-20.1.0.tar.gz", hash = "sha256:6e890b173c5f2dcb4ec62320b3799342ecb1a6a0b2253014455387665d62c213"}, @@ -2018,6 +2073,9 @@ gitpython = [ {file = "GitPython-3.1.2-py3-none-any.whl", hash = "sha256:da3b2cf819974789da34f95ac218ef99f515a928685db141327c09b73dd69c09"}, {file = "GitPython-3.1.2.tar.gz", hash = "sha256:864a47472548f3ba716ca202e034c1900f197c0fb3a08f641c20c3cafd15ed94"}, ] +gpustat = [ + {file = "gpustat-0.6.0.tar.gz", hash = "sha256:f69135080b2668b662822633312c2180002c10111597af9631bb02e042755b6c"}, +] gql = [ {file = "gql-0.2.0.tar.gz", hash = "sha256:ad0f0b8226428d727c8e1d1cac4e521d83ed024d814921bd55b8adb997dadf4b"}, ] @@ -2515,6 +2573,13 @@ qtpy = [ {file = "QtPy-1.9.0-py2.py3-none-any.whl", hash = "sha256:fa0b8363b363e89b2a6f49eddc162a04c0699ae95e109a6be3bb145a913190ea"}, {file = "QtPy-1.9.0.tar.gz", hash = "sha256:2db72c44b55d0fe1407be8fba35c838ad0d6d3bb81f23007886dc1fc0f459c8d"}, ] +redis = [ + {file = "redis-3.5.3-py2.py3-none-any.whl", hash = "sha256:432b788c4530cfe16d8d943a09d40ca6c16149727e4afe8c2c9d5580c59d9f24"}, + {file = "redis-3.5.3.tar.gz", hash = "sha256:0e7e0cfca8660dea8b7d5cd8c4f6c5e29e11f31158c0b0ae91a397f00e5a05a2"}, +] +redlock-py = [ + {file = "redlock-py-1.0.8.tar.gz", hash = "sha256:0b8722c4843ddeabc2fc1dd37c05859e0da29fbce3bd1f6ecc73c98396f139ac"}, +] regex = [ {file = "regex-2020.5.14-cp27-cp27m-win32.whl", hash = "sha256:e565569fc28e3ba3e475ec344d87ed3cd8ba2d575335359749298a0899fe122e"}, {file = "regex-2020.5.14-cp27-cp27m-win_amd64.whl", hash = "sha256:d466967ac8e45244b9dfe302bbe5e3337f8dc4dec8d7d10f5e950d83b140d33a"}, diff --git a/pyproject.toml b/pyproject.toml index 3afeee7..bc4f39d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,8 @@ typeguard = "^2.7.1" xdoctest = "^0.12.0" sphinx = "^3.0.4" jupyter = "^1.0.0" +gpustat = "^0.6.0" +redlock-py = "^1.0.8" [tool.coverage.report] fail_under = 50 diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/Untitled.ipynb new file mode 100644 index 0000000..1cb7acb --- /dev/null +++ b/src/notebooks/Untitled.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "getattr(torch.optim.lr_scheduler, \"StepLR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loss = getattr(torch.nn, \"L1Loss\")()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input = torch.randn(3, 5, requires_grad=True)\n", + "target = torch.randn(3, 5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output = loss(input, target)\n", + "output.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s = 1.\n", + "if s is not None:\n", + " assert 0.0 < s < 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class A:\n", + " @property\n", + " def __name__(self):\n", + " return \"adafa\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a = A()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a.__name__" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from training.gpu_manager import GPUManager" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "gpu_manager = GPUManager(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-05 21:16:55.100 | DEBUG | training.gpu_manager:_get_free_gpu:55 - pid 29777 picking gpu 0\n", + "2020-07-05 21:16:55.704 | DEBUG | training.gpu_manager:_get_free_gpu:59 - pid 29777 could not get lock.\n", + "2020-07-05 21:16:55.705 | DEBUG | training.gpu_manager:get_free_gpu:37 - pid 29777 sleeping\n", + "2020-07-05 21:17:00.722 | DEBUG | training.gpu_manager:_get_free_gpu:55 - pid 29777 picking gpu 0\n" + ] + }, + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gpu_manager.get_free_gpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 929cfb9..aec5bf9 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,2 +1,2 @@ """Dataset modules.""" -from .data_loader import fetch_data_loader +from .emnist_dataset import EmnistDataLoader diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py deleted file mode 100644 index fd55934..0000000 --- a/src/text_recognizer/datasets/data_loader.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Data loader collection.""" - -from typing import Dict - -from torch.utils.data import DataLoader - -from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader - - -def fetch_data_loader(data_loader_args: Dict) -> DataLoader: - """Fetches the specified PyTorch data loader.""" - if data_loader_args.pop("name").lower() == "emnist": - return fetch_emnist_data_loader(data_loader_args) - else: - raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index f9c8ffa..a17d7a9 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -24,7 +24,7 @@ class Transpose: return np.array(image).swapaxes(0, 1) -def save_emnist_essentials(emnsit_dataset: EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -45,111 +45,163 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def load_emnist_mapping() -> Dict: +def load_emnist_mapping() -> Dict[int, str]: """Load the EMNIST mapping.""" with open(str(ESSENTIALS_FILENAME)) as f: essentials = json.load(f) return dict(essentials["mapping"]) -def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(seed) - x = dataset.data - y = dataset.targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_inds = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_inds.append(sampled_inds) - ind = np.concatenate(all_sampled_inds) - x_sampled = x[ind] - y_sampled = y[ind] - dataset.data = x_sampled - dataset.targets = y_sampled - - -def fetch_emnist_dataset( - split: str, - train: bool, - sample_to_balance: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, -) -> EMNIST: - """Fetch the EMNIST dataset.""" - if transform is None: - transform = Compose([Transpose(), ToTensor()]) - - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=train, - download=False, - transform=transform, - target_transform=target_transform, - ) - - if sample_to_balance and split == "byclass": - _sample_to_balance(dataset) - - return dataset - - -def fetch_emnist_data_loader( - splits: List[str], - sample_to_balance: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, -) -> Dict[DataLoader]: - """Fetches the EMNIST dataset and return a PyTorch DataLoader. - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. - Defaults to False. - transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a - transformed version. E.g, transforms.RandomCrop. Defaults to None. - target_transform (Optional[Callable]): A function/transform that takes in the target and transforms - it. - Defaults to None. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - - Returns: - Dict: A dict containing PyTorch DataLoader(s) with emnist characters. - - """ - data_loaders = {} - - for split in ["train", "val"]: - if split in splits: - - if split == "train": - train = True - else: - train = False - - dataset = fetch_emnist_dataset( - split, train, sample_to_balance, transform, target_transform - ) - - data_loader = DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - pin_memory=cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders +class EmnistDataLoader: + """Class for Emnist DataLoaders.""" + + def __init__( + self, + splits: List[str], + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, + seed: int = 4711, + ) -> None: + """Fetches DataLoaders. + + Args: + splits (List[str]): One or both of the dataset splits "train" and "val". + sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. + Defaults to False. + subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire + dataset will be loaded. + transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a + transformed version. E.g, transforms.RandomCrop. Defaults to None. + target_transform (Optional[Callable]): A function/transform that takes in the target and + transforms it. Defaults to None. + batch_size (int): How many samples per batch to load. Defaults to 128. + shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be + loaded in the main process. Defaults to 0. + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning + them. Defaults to True. + seed (int): Seed for sampling. + + """ + self.splits = splits + self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: + assert ( + 0.0 < subsample_fraction < 1.0 + ), " The subsample fraction must be in (0, 1)." + self.subsample_fraction = subsample_fraction + self.transform = transform + self.target_transform = target_transform + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.cuda = cuda + self._data_loaders = self._fetch_emnist_data_loaders() + + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EMNIST" + + def __call__(self, split: str) -> Optional[DataLoader]: + """Returns the `split` DataLoader. + + Args: + split (str): The dataset split, i.e. train or val. + + Returns: + type: A PyTorch DataLoader. + + Raises: + ValueError: If the split does not exist. + + """ + try: + return self._data_loaders[split] + except KeyError: + raise ValueError(f"Split {split} does not exist.") + + def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = dataset.data + y = dataset.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + dataset.data = x_sampled + dataset.targets = y_sampled + + return dataset + + def _subsample(self, dataset: type = EMNIST) -> EMNIST: + """Subsamples the dataset to the specified fraction.""" + x = dataset.data + y = dataset.targets + num_samples = int(x.shape[0] * self.subsample_fraction) + x_sampled = x[:num_samples] + y_sampled = y[:num_samples] + dataset.data = x_sampled + dataset.targets = y_sampled + + return dataset + + def _fetch_emnist_dataset(self, train: bool) -> EMNIST: + """Fetch the EMNIST dataset.""" + if self.transform is None: + transform = Compose([Transpose(), ToTensor()]) + + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=train, + download=False, + transform=transform, + target_transform=self.target_transform, + ) + + if self.sample_to_balance: + dataset = self._sample_to_balance(dataset) + + if self.subsample_fraction is not None: + dataset = self._subsample(dataset) + + return dataset + + def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: + """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" + data_loaders = {} + + for split in ["train", "val"]: + if split in self.splits: + + if split == "train": + train = True + else: + train = False + + dataset = self._fetch_emnist_dataset(train) + + data_loader = DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.cuda, + ) + + data_loaders[split] = data_loader + + return data_loaders diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 736af7b..0cc531a 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -10,7 +10,6 @@ import torch from torch import nn from torchsummary import summary -from text_recognizer.dataset.data_loader import fetch_data_loader WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -22,6 +21,7 @@ class Model(ABC): self, network_fn: Callable, network_args: Dict, + data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -32,12 +32,13 @@ class Model(ABC): lr_scheduler_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: - """Base class, to be inherited by predictors for specific type of data. + """Base class, to be inherited by model for specific type of data. Args: network_fn (Callable): The PyTorch network. network_args (Dict): Arguments for the network. - data_loader_args (Optional[Dict]): Arguments for the data loader. + data_loader (Optional[Callable]): A function that fetches train and val DataLoader. + data_loader_args (Optional[Dict]): Arguments for the DataLoader. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. Defaults to None. @@ -53,8 +54,8 @@ class Model(ABC): # Fetch data loaders. if data_loader_args is not None: - self._data_loaders = fetch_data_loader(**data_loader_args) - dataset_name = self._data_loaders.items()[0].dataset.__name__ + self._data_loaders = data_loader(**data_loader_args) + dataset_name = self._data_loaders.__name__ else: dataset_name = "" self._data_loaders = None @@ -210,7 +211,6 @@ class Model(ABC): logger.debug( f"Found a new best {val_metric}. Saving best checkpoint and weights." ) - self.save_weights() shutil.copyfile(filepath, str(path / "best.pt")) def load_weights(self) -> None: diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 1570344..fd69bf2 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -32,17 +32,21 @@ class CharacterModel(Model): super().__init__( network_fn, - data_loader_args, network_args, + data_loader_args, metrics, criterion, + criterion_args, optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, device, ) self.emnist_mapping = self.mapping() self.eval() - def mapping(self) -> Dict: + def mapping(self) -> Dict[int, str]: """Mapping between integers and classes.""" mapping = load_emnist_mapping() return mapping diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/metrics.py index 905fe7b..e2a30a9 100644 --- a/src/text_recognizer/models/util.py +++ b/src/text_recognizer/models/metrics.py @@ -4,7 +4,7 @@ import torch def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: - """Short summary. + """Computes the accuracy. Args: outputs (torch.Tensor): The output from the network. diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py new file mode 100644 index 0000000..ce1b3dd --- /dev/null +++ b/src/training/gpu_manager.py @@ -0,0 +1,62 @@ +"""GPUManager class.""" +import os +import time +from typing import Optional + +import gpustat +from loguru import logger +import numpy as np +from redlock import Redlock + + +GPU_LOCK_TIMEOUT = 5000 # ms + + +class GPUManager: + """Class for allocating GPUs.""" + + def __init__(self, verbose: bool = False) -> None: + """Initializes Redlock manager.""" + self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) + self.verbose = verbose + + def get_free_gpu(self) -> int: + """Gets a free GPU. + + If some GPUs are available, try reserving one by checking out an exclusive redis lock. + If none available or can not get lock, sleep and check again. + + Returns: + int: The gpu index. + + """ + while True: + gpu_index = self._get_free_gpu() + if gpu_index is not None: + return gpu_index + + if self.verbose: + logger.debug(f"pid {os.getpid()} sleeping") + time.sleep(GPU_LOCK_TIMEOUT / 1000) + + def _get_free_gpu(self) -> Optional[int]: + """Fetches an available GPU index.""" + try: + available_gpu_indices = [ + gpu.index + for gpu in gpustat.GPUStatCollection.new_query() + if gpu.memory_used < 0.5 * gpu.memory_total + ] + except Exception as e: + logger.debug(f"Got the following exception: {e}") + return None + + if available_gpu_indices: + gpu_index = np.random.choice(available_gpu_indices) + if self.verbose: + logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") + if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): + return int(gpu_index) + if self.verbose: + logger.debug(f"pid {os.getpid()} could not get lock.") + return None diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py new file mode 100644 index 0000000..1ab8f00 --- /dev/null +++ b/src/training/prepare_experiments.py @@ -0,0 +1,35 @@ +"""Run a experiment from a config file.""" +import json + +import click +from loguru import logger +import yaml + + +def run_experiment(experiment_filename: str) -> None: + """Run experiment from file.""" + with open(experiment_filename) as f: + experiments_config = yaml.safe_load(f) + num_experiments = len(experiments_config["experiments"]) + for index in range(num_experiments): + experiment_config = experiments_config["experiments"][index] + experiment_config["experiment_group"] = experiments_config["experiment_group"] + print( + f"python training/run_experiment.py --gpu=-1 '{json.dumps(experiment_config)}'" + ) + + +@click.command() +@click.option( + "--experiments_filename", + required=True, + type=str, + help="Filename of Yaml file of experiments to run.", +) +def main(experiment_filename: str) -> None: + """Parse command-line arguments and run experiments from provided file.""" + run_experiment(experiment_filename) + + +if __name__ == "__main__": + main() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 8033f47..8296e59 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -1 +1,75 @@ """Script to run experiments.""" +import importlib +import os +from typing import Dict + +import click +import torch +from training.train import Trainer + + +def run_experiment( + experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False +) -> None: + """Short summary.""" + # Import the data loader module and arguments. + datasets_module = importlib.import_module("text_recognizer.datasets") + data_loader_ = getattr(datasets_module, experiment_config["dataloader"]) + data_loader_args = experiment_config.get("data_loader_args", {}) + + # Import the model module and model arguments. + models_module = importlib.import_module("text_recognizer.models") + model_class_ = getattr(models_module, experiment_config["model"]) + + # Import metric. + metric_fn_ = getattr(models_module, experiment_config["metric"]) + + # Import network module and arguments. + network_module = importlib.import_module("text_recognizer.networks") + network_fn_ = getattr(network_module, experiment_config["network"]) + network_args = experiment_config.get("network_args", {}) + + # Criterion + criterion_ = getattr(torch.nn, experiment_config["criterion"]) + criterion_args = experiment_config.get("criterion_args", {}) + + # Optimizer + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]) + optimizer_args = experiment_config.get("optimizer_args", {}) + + # Learning rate scheduler + lr_scheduler_ = None + lr_scheduler_args = None + if experiment_config["lr_scheduler"] is not None: + lr_scheduler_ = getattr( + torch.optim.lr_scheduler, experiment_config["lr_scheduler"] + ) + lr_scheduler_args = experiment_config.get("lr_scheduler_args", {}) + + # Device + # TODO fix gpu manager + device = None + + model = model_class_( + network_fn=network_fn_, + network_args=network_args, + data_loader=data_loader_, + data_loader_args=data_loader_args, + metrics=metric_fn_, + criterion=criterion_, + criterion_args=criterion_args, + optimizer=optimizer_, + optimizer_args=optimizer_args, + lr_scheduler=lr_scheduler_, + lr_scheduler_args=lr_scheduler_args, + device=device, + ) + + # TODO: Fix checkpoint path and wandb + trainer = Trainer( + model=model, + epochs=experiment_config["epochs"], + val_metric=experiment_config["metric"], + ) + + trainer.fit() diff --git a/src/training/train.py b/src/training/train.py index 783de02..4a452b6 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -9,6 +9,8 @@ import numpy as np import torch from tqdm import tqdm, trange from training.util import RunningAverage +import wandb + torch.backends.cudnn.benchmark = True np.random.seed(4711) @@ -30,6 +32,7 @@ class Trainer: epochs: int, val_metric: str = "accuracy", checkpoint_path: Optional[Path] = None, + use_wandb: Optional[bool] = False, ) -> None: """Initialization of the Trainer. @@ -38,6 +41,7 @@ class Trainer: epochs (int): Number of epochs to train. val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy". checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + use_wandb (Optional[bool]): Sync training to wandb. """ self.model = model @@ -48,13 +52,16 @@ class Trainer: if self.checkpoint_path is not None: self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + if use_wandb: + # TODO implement wandb logging. + pass + self.val_metric = val_metric self.best_val_metric = 0.0 logger.add(self.model.name + "_{time}.log") def train(self) -> None: """Training loop.""" - # TODO add summary # Set model to traning mode. self.model.train() @@ -93,10 +100,6 @@ class Trainer: # Perform updates using calculated gradients. self.model.optimizer.step() - # Update the learning rate scheduler. - if self.model.lr_scheduler is not None: - self.model.lr_scheduler.step() - # Compute metrics. loss_avg.update(loss.item()) output = output.data.cpu() @@ -174,8 +177,8 @@ class Trainer: return metrics_mean - def run(self) -> None: - """Training and evaluation loop.""" + def fit(self) -> None: + """Runs the training and evaluation loop.""" # Create new experiment. EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment = datetime.now().strftime("%m%d_%H%M%S") |