Diffstat (limited to 'src')
-rw-r--r--  src/notebooks/Untitled.ipynb                     177
-rw-r--r--  src/text_recognizer/datasets/__init__.py           2
-rw-r--r--  src/text_recognizer/datasets/data_loader.py       15
-rw-r--r--  src/text_recognizer/datasets/emnist_dataset.py   258
-rw-r--r--  src/text_recognizer/models/base.py                12
-rw-r--r--  src/text_recognizer/models/character_model.py      8
-rw-r--r--  src/text_recognizer/models/metrics.py (renamed from src/text_recognizer/models/util.py)   2
-rw-r--r--  src/training/gpu_manager.py                       62
-rw-r--r--  src/training/prepare_experiments.py               35
-rw-r--r--  src/training/run_experiment.py                    74
-rw-r--r--  src/training/train.py                             17
11 files changed, 527 insertions, 135 deletions
diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/Untitled.ipynb
new file mode 100644
index 0000000..1cb7acb
--- /dev/null
+++ b/src/notebooks/Untitled.ipynb
@@ -0,0 +1,177 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "getattr(torch.optim.lr_scheduler, \"StepLR\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "loss = getattr(torch.nn, \"L1Loss\")()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "input = torch.randn(3, 5, requires_grad=True)\n",
+ "target = torch.randn(3, 5)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output = loss(input, target)\n",
+ "output.backward()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "output"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "s = 1.\n",
+ "if s is not None:\n",
+ " assert 0.0 < s < 1.0"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class A:\n",
+ " @property\n",
+ " def __name__(self):\n",
+ " return \"adafa\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a = A()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a.__name__"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from training.gpu_manager import GPUManager"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "gpu_manager = GPUManager(True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-07-05 21:16:55.100 | DEBUG | training.gpu_manager:_get_free_gpu:55 - pid 29777 picking gpu 0\n",
+ "2020-07-05 21:16:55.704 | DEBUG | training.gpu_manager:_get_free_gpu:59 - pid 29777 could not get lock.\n",
+ "2020-07-05 21:16:55.705 | DEBUG | training.gpu_manager:get_free_gpu:37 - pid 29777 sleeping\n",
+ "2020-07-05 21:17:00.722 | DEBUG | training.gpu_manager:_get_free_gpu:55 - pid 29777 picking gpu 0\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "0"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "gpu_manager.get_free_gpu()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
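
Editor's note: the scratch cells above exercise the getattr-based factory pattern that run_experiment.py (further down in this diff) builds on: resolving torch classes from configuration strings. A minimal, self-contained sketch of that pattern; "L1Loss" and "StepLR" are just example names:

import torch

# Resolve a loss and a scheduler class from their string names,
# exactly as an experiment config entry would be resolved.
criterion_class = getattr(torch.nn, "L1Loss")
scheduler_class = getattr(torch.optim.lr_scheduler, "StepLR")

criterion = criterion_class()
model = torch.nn.Linear(5, 5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = scheduler_class(optimizer, step_size=1)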
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index 929cfb9..aec5bf9 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,2 +1,2 @@
"""Dataset modules."""
-from .data_loader import fetch_data_loader
+from .emnist_dataset import EmnistDataLoader
diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py
deleted file mode 100644
index fd55934..0000000
--- a/src/text_recognizer/datasets/data_loader.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Data loader collection."""
-
-from typing import Dict
-
-from torch.utils.data import DataLoader
-
-from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader
-
-
-def fetch_data_loader(data_loader_args: Dict) -> DataLoader:
- """Fetches the specified PyTorch data loader."""
- if data_loader_args.pop("name").lower() == "emnist":
- return fetch_emnist_data_loader(data_loader_args)
- else:
- raise NotImplementedError
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index f9c8ffa..a17d7a9 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -24,7 +24,7 @@ class Transpose:
return np.array(image).swapaxes(0, 1)
-def save_emnist_essentials(emnsit_dataset: EMNIST) -> None:
+def save_emnist_essentials(emnist_dataset: type = EMNIST) -> None:
"""Extracts and saves the EMNIST essentials."""
labels = emnist_dataset.classes
labels.sort()
@@ -45,111 +45,163 @@ def download_emnist() -> None:
save_emnist_essentials(dataset)
-def load_emnist_mapping() -> Dict:
+def load_emnist_mapping() -> Dict[int, str]:
"""Load the EMNIST mapping."""
with open(str(ESSENTIALS_FILENAME)) as f:
essentials = json.load(f)
return dict(essentials["mapping"])
-def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None:
- """Because the dataset is not balanced, we take at most the mean number of instances per class."""
- np.random.seed(seed)
- x = dataset.data
- y = dataset.targets
- num_to_sample = int(np.bincount(y.flatten()).mean())
- all_sampled_inds = []
- for label in np.unique(y.flatten()):
- inds = np.where(y == label)[0]
- sampled_inds = np.unique(np.random.choice(inds, num_to_sample))
- all_sampled_inds.append(sampled_inds)
- ind = np.concatenate(all_sampled_inds)
- x_sampled = x[ind]
- y_sampled = y[ind]
- dataset.data = x_sampled
- dataset.targets = y_sampled
-
-
-def fetch_emnist_dataset(
- split: str,
- train: bool,
- sample_to_balance: bool = False,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
-) -> EMNIST:
- """Fetch the EMNIST dataset."""
- if transform is None:
- transform = Compose([Transpose(), ToTensor()])
-
- dataset = EMNIST(
- root=DATA_DIRNAME,
- split="byclass",
- train=train,
- download=False,
- transform=transform,
- target_transform=target_transform,
- )
-
- if sample_to_balance and split == "byclass":
- _sample_to_balance(dataset)
-
- return dataset
-
-
-def fetch_emnist_data_loader(
- splits: List[str],
- sample_to_balance: bool = False,
- transform: Optional[Callable] = None,
- target_transform: Optional[Callable] = None,
- batch_size: int = 128,
- shuffle: bool = False,
- num_workers: int = 0,
- cuda: bool = True,
-) -> Dict[DataLoader]:
- """Fetches the EMNIST dataset and return a PyTorch DataLoader.
-
- Args:
- splits (List[str]): One or both of the dataset splits "train" and "val".
- sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected.
- Defaults to False.
- transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
- transformed version. E.g, transforms.RandomCrop. Defaults to None.
- target_transform (Optional[Callable]): A function/transform that takes in the target and transforms
- it.
- Defaults to None.
- batch_size (int): How many samples per batch to load. Defaults to 128.
- shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
- num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
- loaded in the main process. Defaults to 0.
- cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
- them. Defaults to True.
-
- Returns:
- Dict: A dict containing PyTorch DataLoader(s) with emnist characters.
-
- """
- data_loaders = {}
-
- for split in ["train", "val"]:
- if split in splits:
-
- if split == "train":
- train = True
- else:
- train = False
-
- dataset = fetch_emnist_dataset(
- split, train, sample_to_balance, transform, target_transform
- )
-
- data_loader = DataLoader(
- dataset=dataset,
- batch_size=batch_size,
- shuffle=shuffle,
- num_workers=num_workers,
- pin_memory=cuda,
- )
-
- data_loaders[split] = data_loader
-
- return data_loaders
+class EmnistDataLoader:
+ """Class for Emnist DataLoaders."""
+
+ def __init__(
+ self,
+ splits: List[str],
+ sample_to_balance: bool = False,
+ subsample_fraction: Optional[float] = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ batch_size: int = 128,
+ shuffle: bool = False,
+ num_workers: int = 0,
+ cuda: bool = True,
+ seed: int = 4711,
+ ) -> None:
+ """Fetches DataLoaders.
+
+ Args:
+ splits (List[str]): One or both of the dataset splits "train" and "val".
+ sample_to_balance (bool): If true, resamples the unbalanced dataset if the split "byclass" is
+ selected. Defaults to False.
+ subsample_fraction (float): The fraction of the dataset that will be loaded. If None, the
+ entire dataset will be loaded.
+ transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a
+ transformed version. E.g, transforms.RandomCrop. Defaults to None.
+ target_transform (Optional[Callable]): A function/transform that takes in the target and
+ transforms it. Defaults to None.
+ batch_size (int): How many samples per batch to load. Defaults to 128.
+ shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False.
+ num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be
+ loaded in the main process. Defaults to 0.
+ cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning
+ them. Defaults to True.
+ seed (int): Seed for sampling.
+
+ """
+ self.splits = splits
+ self.sample_to_balance = sample_to_balance
+ if subsample_fraction is not None:
+ assert (
+ 0.0 < subsample_fraction < 1.0
+ ), " The subsample fraction must be in (0, 1)."
+ self.subsample_fraction = subsample_fraction
+ self.transform = transform
+ self.target_transform = target_transform
+ self.batch_size = batch_size
+ self.shuffle = shuffle
+ self.num_workers = num_workers
+ self.cuda = cuda
+ self.seed = seed
+ self._data_loaders = self._fetch_emnist_data_loaders()
+
+ @property
+ def __name__(self) -> str:
+ """Returns the name of the dataset."""
+ return "EMNIST"
+
+ def __call__(self, split: str) -> Optional[DataLoader]:
+ """Returns the `split` DataLoader.
+
+ Args:
+ split (str): The dataset split, i.e. train or val.
+
+ Returns:
+ DataLoader: A PyTorch DataLoader.
+
+ Raises:
+ ValueError: If the split does not exist.
+
+ """
+ try:
+ return self._data_loaders[split]
+ except KeyError:
+ raise ValueError(f"Split {split} does not exist.")
+
+ def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST:
+ """Because the dataset is not balanced, we take at most the mean number of instances per class."""
+ np.random.seed(self.seed)
+ x = dataset.data
+ y = dataset.targets
+ num_to_sample = int(np.bincount(y.flatten()).mean())
+ all_sampled_indices = []
+ for label in np.unique(y.flatten()):
+ inds = np.where(y == label)[0]
+ sampled_indices = np.unique(np.random.choice(inds, num_to_sample))
+ all_sampled_indices.append(sampled_indices)
+ indices = np.concatenate(all_sampled_indices)
+ x_sampled = x[indices]
+ y_sampled = y[indices]
+ dataset.data = x_sampled
+ dataset.targets = y_sampled
+
+ return dataset
+
+ def _subsample(self, dataset: type = EMNIST) -> EMNIST:
+ """Subsamples the dataset to the specified fraction."""
+ x = dataset.data
+ y = dataset.targets
+ num_samples = int(x.shape[0] * self.subsample_fraction)
+ x_sampled = x[:num_samples]
+ y_sampled = y[:num_samples]
+ dataset.data = x_sampled
+ dataset.targets = y_sampled
+
+ return dataset
+
+ def _fetch_emnist_dataset(self, train: bool) -> EMNIST:
+ """Fetch the EMNIST dataset."""
+ transform = self.transform
+ if transform is None:
+ transform = Compose([Transpose(), ToTensor()])
+
+ dataset = EMNIST(
+ root=DATA_DIRNAME,
+ split="byclass",
+ train=train,
+ download=False,
+ transform=transform,
+ target_transform=self.target_transform,
+ )
+
+ if self.sample_to_balance:
+ dataset = self._sample_to_balance(dataset)
+
+ if self.subsample_fraction is not None:
+ dataset = self._subsample(dataset)
+
+ return dataset
+
+ def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]:
+ """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders."""
+ data_loaders = {}
+
+ for split in ["train", "val"]:
+ if split in self.splits:
+
+ if split == "train":
+ train = True
+ else:
+ train = False
+
+ dataset = self._fetch_emnist_dataset(train)
+
+ data_loader = DataLoader(
+ dataset=dataset,
+ batch_size=self.batch_size,
+ shuffle=self.shuffle,
+ num_workers=self.num_workers,
+ pin_memory=self.cuda,
+ )
+
+ data_loaders[split] = data_loader
+
+ return data_loaders
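
Editor's note: a usage sketch of the new loader class (assuming the EMNIST data and essentials file have already been downloaded; the argument values are illustrative):

from text_recognizer.datasets import EmnistDataLoader

data_loaders = EmnistDataLoader(
    splits=["train", "val"],
    subsample_fraction=0.1,  # quick runs on 10% of the data
    batch_size=128,
    shuffle=True,
    cuda=False,
)
train_loader = data_loaders("train")  # __call__ returns the DataLoader for a split
for images, targets in train_loader:
    break  # one batch of transposed EMNIST images and integer labels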
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 736af7b..0cc531a 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -10,7 +10,6 @@ import torch
from torch import nn
from torchsummary import summary
-from text_recognizer.dataset.data_loader import fetch_data_loader
WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights"
@@ -22,6 +21,7 @@ class Model(ABC):
self,
network_fn: Callable,
network_args: Dict,
+ data_loader: Optional[Callable] = None,
data_loader_args: Optional[Dict] = None,
metrics: Optional[Dict] = None,
criterion: Optional[Callable] = None,
@@ -32,12 +32,13 @@ class Model(ABC):
lr_scheduler_args: Optional[Dict] = None,
device: Optional[str] = None,
) -> None:
- """Base class, to be inherited by predictors for specific type of data.
+ """Base class, to be inherited by model for specific type of data.
Args:
network_fn (Callable): The PyTorch network.
network_args (Dict): Arguments for the network.
- data_loader_args (Optional[Dict]): Arguments for the data loader.
+ data_loader (Optional[Callable]): A callable that fetches the train and val DataLoaders.
+ data_loader_args (Optional[Dict]): Arguments for the DataLoader.
metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None.
criterion (Optional[Callable]): The criterion to evaluate the performance of the network.
Defaults to None.
@@ -53,8 +54,8 @@ class Model(ABC):
# Fetch data loaders.
if data_loader_args is not None:
- self._data_loaders = fetch_data_loader(**data_loader_args)
- dataset_name = self._data_loaders.items()[0].dataset.__name__
+ self._data_loaders = data_loader(**data_loader_args)
+ dataset_name = self._data_loaders.__name__
else:
dataset_name = ""
self._data_loaders = None
@@ -210,7 +211,6 @@ class Model(ABC):
logger.debug(
f"Found a new best {val_metric}. Saving best checkpoint and weights."
)
- self.save_weights()
shutil.copyfile(filepath, str(path / "best.pt"))
def load_weights(self) -> None:
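
Editor's note: `dataset_name = self._data_loaders.__name__` works because EmnistDataLoader exposes `__name__` as an instance property (the pattern probed in the notebook above). A minimal sketch of why the property is needed:

class PlainLoader:
    pass

class NamedLoader:
    @property
    def __name__(self) -> str:
        return "EMNIST"

# Instances normally have no __name__ (only classes and functions do, via the
# metaclass), so the property makes the instance itself self-describing.
print(NamedLoader().__name__)                    # "EMNIST"
print(getattr(PlainLoader(), "__name__", None))  # None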
diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py
index 1570344..fd69bf2 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/src/text_recognizer/models/character_model.py
@@ -32,17 +32,21 @@ class CharacterModel(Model):
super().__init__(
network_fn,
- data_loader_args,
network_args,
+ data_loader_args,
metrics,
criterion,
+ criterion_args,
optimizer,
+ optimizer_args,
+ lr_scheduler,
+ lr_scheduler_args,
device,
)
self.emnist_mapping = self.mapping()
self.eval()
- def mapping(self) -> Dict:
+ def mapping(self) -> Dict[int, str]:
"""Mapping between integers and classes."""
mapping = load_emnist_mapping()
return mapping
diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/metrics.py
index 905fe7b..e2a30a9 100644
--- a/src/text_recognizer/models/util.py
+++ b/src/text_recognizer/models/metrics.py
@@ -4,7 +4,7 @@ import torch
def accuracy(outputs: torch.Tensor, labels: torch.Tensor) -> float:
- """Short summary.
+ """Computes the accuracy.
Args:
outputs (torch.Tensor): The output from the network.
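
Editor's note: a usage sketch for the renamed metrics module (assuming `accuracy` compares argmax predictions to labels, which this hunk does not show):

import torch
from text_recognizer.models.metrics import accuracy

outputs = torch.randn(8, 62)         # logits for a batch over 62 byclass labels
labels = torch.randint(0, 62, (8,))  # ground-truth label indices
print(accuracy(outputs, labels))     # fraction of correct predictions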
diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py
new file mode 100644
index 0000000..ce1b3dd
--- /dev/null
+++ b/src/training/gpu_manager.py
@@ -0,0 +1,62 @@
+"""GPUManager class."""
+import os
+import time
+from typing import Optional
+
+import gpustat
+from loguru import logger
+import numpy as np
+from redlock import Redlock
+
+
+GPU_LOCK_TIMEOUT = 5000 # ms
+
+
+class GPUManager:
+ """Class for allocating GPUs."""
+
+ def __init__(self, verbose: bool = False) -> None:
+ """Initializes Redlock manager."""
+ self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}])
+ self.verbose = verbose
+
+ def get_free_gpu(self) -> int:
+ """Gets a free GPU.
+
+ If a GPU is available, try to reserve it by acquiring an exclusive Redis lock.
+ If none are available or the lock cannot be acquired, sleep and check again.
+
+ Returns:
+ int: The gpu index.
+
+ """
+ while True:
+ gpu_index = self._get_free_gpu()
+ if gpu_index is not None:
+ return gpu_index
+
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} sleeping")
+ time.sleep(GPU_LOCK_TIMEOUT / 1000)
+
+ def _get_free_gpu(self) -> Optional[int]:
+ """Fetches an available GPU index."""
+ try:
+ available_gpu_indices = [
+ gpu.index
+ for gpu in gpustat.GPUStatCollection.new_query()
+ if gpu.memory_used < 0.5 * gpu.memory_total
+ ]
+ except Exception as e:
+ logger.debug(f"Got the following exception: {e}")
+ return None
+
+ if available_gpu_indices:
+ gpu_index = np.random.choice(available_gpu_indices)
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}")
+ if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT):
+ return int(gpu_index)
+ if self.verbose:
+ logger.debug(f"pid {os.getpid()} could not get lock.")
+ return None
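
Editor's note: a usage sketch (assuming a Redis server is listening on localhost:6379, which the Redlock manager requires):

import torch
from training.gpu_manager import GPUManager

gpu_manager = GPUManager(verbose=True)
gpu_index = gpu_manager.get_free_gpu()  # blocks until a GPU lock is acquired
device = torch.device(f"cuda:{gpu_index}")

This matches the notebook session above: `get_free_gpu()` retries after the lock timeout until it can claim an index.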
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
new file mode 100644
index 0000000..1ab8f00
--- /dev/null
+++ b/src/training/prepare_experiments.py
@@ -0,0 +1,35 @@
+"""Run a experiment from a config file."""
+import json
+
+import click
+from loguru import logger
+import yaml
+
+
+def run_experiment(experiments_filename: str) -> None:
+ """Print the run command for each experiment in the file."""
+ with open(experiments_filename) as f:
+ experiments_config = yaml.safe_load(f)
+ num_experiments = len(experiments_config["experiments"])
+ for index in range(num_experiments):
+ experiment_config = experiments_config["experiments"][index]
+ experiment_config["experiment_group"] = experiments_config["experiment_group"]
+ print(
+ f"python training/run_experiment.py --gpu=-1 '{json.dumps(experiment_config)}'"
+ )
+
+
+@click.command()
+@click.option(
+ "--experiments_filename",
+ required=True,
+ type=str,
+ help="Filename of Yaml file of experiments to run.",
+)
+def main(experiments_filename: str) -> None:
+ """Parse the command-line arguments and prepare experiments from the provided file."""
+ run_experiment(experiments_filename)
+
+
+if __name__ == "__main__":
+ main()
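
Editor's note: the script expects a YAML file with an `experiment_group` key and a list under `experiments`; the field names inside each experiment below are an assumed illustration. A self-contained sketch of the same logic:

import json
import yaml

experiments_config = yaml.safe_load(
    """
experiment_group: sample_group
experiments:
  - dataloader: EmnistDataLoader
    model: CharacterModel
"""
)
for experiment_config in experiments_config["experiments"]:
    experiment_config["experiment_group"] = experiments_config["experiment_group"]
    print(f"python training/run_experiment.py --gpu=-1 '{json.dumps(experiment_config)}'")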
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 8033f47..8296e59 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -1 +1,75 @@
"""Script to run experiments."""
+import importlib
+import os
+from typing import Dict
+
+import click
+import torch
+from training.train import Trainer
+
+
+def run_experiment(
+ experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False
+) -> None:
+ """Short summary."""
+ # Import the data loader module and arguments.
+ datasets_module = importlib.import_module("text_recognizer.datasets")
+ data_loader_ = getattr(datasets_module, experiment_config["dataloader"])
+ data_loader_args = experiment_config.get("data_loader_args", {})
+
+ # Import the model module and model arguments.
+ models_module = importlib.import_module("text_recognizer.models")
+ model_class_ = getattr(models_module, experiment_config["model"])
+
+ # Import metric.
+ metric_fn_ = getattr(models_module, experiment_config["metric"])
+
+ # Import network module and arguments.
+ network_module = importlib.import_module("text_recognizer.networks")
+ network_fn_ = getattr(network_module, experiment_config["network"])
+ network_args = experiment_config.get("network_args", {})
+
+ # Criterion
+ criterion_ = getattr(torch.nn, experiment_config["criterion"])
+ criterion_args = experiment_config.get("criterion_args", {})
+
+ # Optimizer
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"])
+ optimizer_args = experiment_config.get("optimizer_args", {})
+
+ # Learning rate scheduler
+ lr_scheduler_ = None
+ lr_scheduler_args = None
+ if experiment_config["lr_scheduler"] is not None:
+ lr_scheduler_ = getattr(
+ torch.optim.lr_scheduler, experiment_config["lr_scheduler"]
+ )
+ lr_scheduler_args = experiment_config.get("lr_scheduler_args", {})
+
+ # Device
+ # TODO fix gpu manager
+ device = None
+
+ model = model_class_(
+ network_fn=network_fn_,
+ network_args=network_args,
+ data_loader=data_loader_,
+ data_loader_args=data_loader_args,
+ metrics=metric_fn_,
+ criterion=criterion_,
+ criterion_args=criterion_args,
+ optimizer=optimizer_,
+ optimizer_args=optimizer_args,
+ lr_scheduler=lr_scheduler_,
+ lr_scheduler_args=lr_scheduler_args,
+ device=device,
+ )
+
+ # TODO: Fix checkpoint path and wandb
+ trainer = Trainer(
+ model=model,
+ epochs=experiment_config["epochs"],
+ val_metric=experiment_config["metric"],
+ )
+
+ trainer.fit()
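
Editor's note: the config keys the script reads are fixed by the code above; the values below are a hypothetical example, not a tested configuration:

experiment_config = {
    "dataloader": "EmnistDataLoader",
    "data_loader_args": {"splits": ["train", "val"], "batch_size": 128},
    "model": "CharacterModel",
    "metric": "accuracy",
    "network": "MLP",  # assumed network name
    "network_args": {},
    "criterion": "CrossEntropyLoss",
    "criterion_args": {},
    "optimizer": "Adam",
    "optimizer_args": {"lr": 3e-4},
    "lr_scheduler": "StepLR",
    "lr_scheduler_args": {"step_size": 1},
    "epochs": 10,
}

Note that the script indexes experiment_config["lr_scheduler"] directly, so the key must be present (set it to null in the YAML to disable scheduling).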
diff --git a/src/training/train.py b/src/training/train.py
index 783de02..4a452b6 100644
--- a/src/training/train.py
+++ b/src/training/train.py
@@ -9,6 +9,8 @@ import numpy as np
import torch
from tqdm import tqdm, trange
from training.util import RunningAverage
+import wandb
+
torch.backends.cudnn.benchmark = True
np.random.seed(4711)
@@ -30,6 +32,7 @@ class Trainer:
epochs: int,
val_metric: str = "accuracy",
checkpoint_path: Optional[Path] = None,
+ use_wandb: Optional[bool] = False,
) -> None:
"""Initialization of the Trainer.
@@ -38,6 +41,7 @@ class Trainer:
epochs (int): Number of epochs to train.
val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy".
checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None.
+ use_wandb (Optional[bool]): Sync training to wandb.
"""
self.model = model
@@ -48,13 +52,16 @@ class Trainer:
if self.checkpoint_path is not None:
self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
+ if use_wandb:
+ # TODO implement wandb logging.
+ pass
+
self.val_metric = val_metric
self.best_val_metric = 0.0
logger.add(self.model.name + "_{time}.log")
def train(self) -> None:
"""Training loop."""
- # TODO add summary
# Set model to training mode.
self.model.train()
@@ -93,10 +100,6 @@ class Trainer:
# Perform updates using calculated gradients.
self.model.optimizer.step()
- # Update the learning rate scheduler.
- if self.model.lr_scheduler is not None:
- self.model.lr_scheduler.step()
-
# Compute metrics.
loss_avg.update(loss.item())
output = output.data.cpu()
@@ -174,8 +177,8 @@ class Trainer:
return metrics_mean
- def run(self) -> None:
- """Training and evaluation loop."""
+ def fit(self) -> None:
+ """Runs the training and evaluation loop."""
# Create new experiment.
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment = datetime.now().strftime("%m%d_%H%M%S")
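
Editor's note: the hunk above removes the per-batch `lr_scheduler.step()` from `train()`; stepping once per epoch is the usual replacement, though the new call site is not part of this diff, so this is an assumption. A standalone sketch of per-epoch stepping:

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

for epoch in range(3):
    # ... per-batch forward/backward passes happen here ...
    optimizer.step()     # stand-in for the batch-level updates
    lr_scheduler.step()  # decay once per epoch, not once per batch
    print(epoch, optimizer.param_groups[0]["lr"])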