From 5a78fc2e33c28968a69d033cb10d638f4f63fed1 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Sun, 5 Jul 2020 22:27:08 +0200 Subject: Working on getting experiment loop. --- src/notebooks/Untitled.ipynb | 177 +++++++++++++++++ src/text_recognizer/datasets/__init__.py | 2 +- src/text_recognizer/datasets/data_loader.py | 15 -- src/text_recognizer/datasets/emnist_dataset.py | 258 +++++++++++++++---------- src/text_recognizer/models/base.py | 12 +- src/text_recognizer/models/character_model.py | 8 +- src/text_recognizer/models/metrics.py | 19 ++ src/text_recognizer/models/util.py | 19 -- src/training/gpu_manager.py | 62 ++++++ src/training/prepare_experiments.py | 35 ++++ src/training/run_experiment.py | 74 +++++++ src/training/train.py | 17 +- 12 files changed, 545 insertions(+), 153 deletions(-) create mode 100644 src/notebooks/Untitled.ipynb delete mode 100644 src/text_recognizer/datasets/data_loader.py create mode 100644 src/text_recognizer/models/metrics.py delete mode 100644 src/text_recognizer/models/util.py create mode 100644 src/training/gpu_manager.py create mode 100644 src/training/prepare_experiments.py (limited to 'src') diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/Untitled.ipynb new file mode 100644 index 0000000..1cb7acb --- /dev/null +++ b/src/notebooks/Untitled.ipynb @@ -0,0 +1,177 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "getattr(torch.optim.lr_scheduler, \"StepLR\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "loss = getattr(torch.nn, \"L1Loss\")()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "input = torch.randn(3, 5, requires_grad=True)\n", + "target = torch.randn(3, 5)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output = loss(input, target)\n", + "output.backward()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "s = 1.\n", + "if s is not None:\n", + " assert 0.0 < s < 1.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class A:\n", + " @property\n", + " def __name__(self):\n", + " return \"adafa\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a = A()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a.__name__" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from training.gpu_manager import GPUManager" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "gpu_manager = GPUManager(True)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2020-07-05 21:16:55.100 | DEBUG | training.gpu_manager:_get_free_gpu:55 - pid 29777 picking gpu 0\n", + "2020-07-05 21:16:55.704 | DEBUG | training.gpu_manager:_get_free_gpu:59 - pid 29777 could not get lock.\n", + "2020-07-05 21:16:55.705 | DEBUG | training.gpu_manager:get_free_gpu:37 - pid 29777 sleeping\n", + "2020-07-05 21:17:00.722 | DEBUG | training.gpu_manager:_get_free_gpu:55 - pid 29777 picking gpu 0\n" + ] + }, + { + "data": { + "text/plain": [ + "0" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "gpu_manager.get_free_gpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.2" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py index 929cfb9..aec5bf9 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/src/text_recognizer/datasets/__init__.py @@ -1,2 +1,2 @@ """Dataset modules.""" -from .data_loader import fetch_data_loader +from .emnist_dataset import EmnistDataLoader diff --git a/src/text_recognizer/datasets/data_loader.py b/src/text_recognizer/datasets/data_loader.py deleted file mode 100644 index fd55934..0000000 --- a/src/text_recognizer/datasets/data_loader.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Data loader collection.""" - -from typing import Dict - -from torch.utils.data import DataLoader - -from text_recognizer.datasets.emnist_dataset import fetch_emnist_data_loader - - -def fetch_data_loader(data_loader_args: Dict) -> DataLoader: - """Fetches the specified PyTorch data loader.""" - if data_loader_args.pop("name").lower() == "emnist": - return fetch_emnist_data_loader(data_loader_args) - else: - raise NotImplementedError diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py index f9c8ffa..a17d7a9 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/src/text_recognizer/datasets/emnist_dataset.py @@ -24,7 +24,7 @@ class Transpose: return np.array(image).swapaxes(0, 1) -def save_emnist_essentials(emnsit_dataset: EMNIST) -> None: +def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None: """Extract and saves EMNIST essentials.""" labels = emnsit_dataset.classes labels.sort() @@ -45,111 +45,163 @@ def download_emnist() -> None: save_emnist_essentials(dataset) -def load_emnist_mapping() -> Dict: +def load_emnist_mapping() -> Dict[int, str]: """Load the EMNIST mapping.""" with open(str(ESSENTIALS_FILENAME)) as f: essentials = json.load(f) return dict(essentials["mapping"]) -def _sample_to_balance(dataset: EMNIST, seed: int = 4711) -> None: - """Because the dataset is not balanced, we take at most the mean number of instances per class.""" - np.random.seed(seed) - x = dataset.data - y = dataset.targets - num_to_sample = int(np.bincount(y.flatten()).mean()) - all_sampled_inds = [] - for label in np.unique(y.flatten()): - inds = np.where(y == label)[0] - sampled_inds = np.unique(np.random.choice(inds, num_to_sample)) - all_sampled_inds.append(sampled_inds) - ind = np.concatenate(all_sampled_inds) - x_sampled = x[ind] - y_sampled = y[ind] - dataset.data = x_sampled - dataset.targets = y_sampled - - -def fetch_emnist_dataset( - split: str, - train: bool, - sample_to_balance: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, -) -> EMNIST: - """Fetch the EMNIST dataset.""" - if transform is None: - transform = Compose([Transpose(), ToTensor()]) - - dataset = EMNIST( - root=DATA_DIRNAME, - split="byclass", - train=train, - download=False, - transform=transform, - target_transform=target_transform, - ) - - if sample_to_balance and split == "byclass": - _sample_to_balance(dataset) - - return dataset - - -def fetch_emnist_data_loader( - splits: List[str], - sample_to_balance: bool = False, - transform: Optional[Callable] = None, - target_transform: Optional[Callable] = None, - batch_size: int = 128, - shuffle: bool = False, - num_workers: int = 0, - cuda: bool = True, -) -> Dict[DataLoader]: - """Fetches the EMNIST dataset and return a PyTorch DataLoader. - - Args: - splits (List[str]): One or both of the dataset splits "train" and "val". - sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. - Defaults to False. - transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a - transformed version. E.g, transforms.RandomCrop. Defaults to None. - target_transform (Optional[Callable]): A function/transform that takes in the target and transforms - it. - Defaults to None. - batch_size (int): How many samples per batch to load. Defaults to 128. - shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. - num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be - loaded in the main process. Defaults to 0. - cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning - them. Defaults to True. - - Returns: - Dict: A dict containing PyTorch DataLoader(s) with emnist characters. - - """ - data_loaders = {} - - for split in ["train", "val"]: - if split in splits: - - if split == "train": - train = True - else: - train = False - - dataset = fetch_emnist_dataset( - split, train, sample_to_balance, transform, target_transform - ) - - data_loader = DataLoader( - dataset=dataset, - batch_size=batch_size, - shuffle=shuffle, - num_workers=num_workers, - pin_memory=cuda, - ) - - data_loaders[split] = data_loader - - return data_loaders +class EmnistDataLoader: + """Class for Emnist DataLoaders.""" + + def __init__( + self, + splits: List[str], + sample_to_balance: bool = False, + subsample_fraction: float = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + batch_size: int = 128, + shuffle: bool = False, + num_workers: int = 0, + cuda: bool = True, + seed: int = 4711, + ) -> None: + """Fetches DataLoaders. + + Args: + splits (List[str]): One or both of the dataset splits "train" and "val". + sample_to_balance (bool): If true, resamples the unbalanced if the split "byclass" is selected. + Defaults to False. + subsample_fraction (float): The fraction of the dataset will be loaded. If None or 0 the entire + dataset will be loaded. + transform (Optional[Callable]): A function/transform that takes in an PIL image and returns a + transformed version. E.g, transforms.RandomCrop. Defaults to None. + target_transform (Optional[Callable]): A function/transform that takes in the target and + transforms it. Defaults to None. + batch_size (int): How many samples per batch to load. Defaults to 128. + shuffle (bool): Set to True to have the data reshuffled at every epoch. Defaults to False. + num_workers (int): How many subprocesses to use for data loading. 0 means that the data will be + loaded in the main process. Defaults to 0. + cuda (bool): If True, the data loader will copy Tensors into CUDA pinned memory before returning + them. Defaults to True. + seed (int): Seed for sampling. + + """ + self.splits = splits + self.sample_to_balance = sample_to_balance + if subsample_fraction is not None: + assert ( + 0.0 < subsample_fraction < 1.0 + ), " The subsample fraction must be in (0, 1)." + self.subsample_fraction = subsample_fraction + self.transform = transform + self.target_transform = target_transform + self.batch_size = batch_size + self.shuffle = shuffle + self.num_workers = num_workers + self.cuda = cuda + self._data_loaders = self._fetch_emnist_data_loaders() + + @property + def __name__(self) -> str: + """Returns the name of the dataset.""" + return "EMNIST" + + def __call__(self, split: str) -> Optional[DataLoader]: + """Returns the `split` DataLoader. + + Args: + split (str): The dataset split, i.e. train or val. + + Returns: + type: A PyTorch DataLoader. + + Raises: + ValueError: If the split does not exist. + + """ + try: + return self._data_loaders[split] + except KeyError: + raise ValueError(f"Split {split} does not exist.") + + def _sample_to_balance(self, dataset: type = EMNIST) -> EMNIST: + """Because the dataset is not balanced, we take at most the mean number of instances per class.""" + np.random.seed(self.seed) + x = dataset.data + y = dataset.targets + num_to_sample = int(np.bincount(y.flatten()).mean()) + all_sampled_indices = [] + for label in np.unique(y.flatten()): + inds = np.where(y == label)[0] + sampled_indices = np.unique(np.random.choice(inds, num_to_sample)) + all_sampled_indices.append(sampled_indices) + indices = np.concatenate(all_sampled_indices) + x_sampled = x[indices] + y_sampled = y[indices] + dataset.data = x_sampled + dataset.targets = y_sampled + + return dataset + + def _subsample(self, dataset: type = EMNIST) -> EMNIST: + """Subsamples the dataset to the specified fraction.""" + x = dataset.data + y = dataset.targets + num_samples = int(x.shape[0] * self.subsample_fraction) + x_sampled = x[:num_samples] + y_sampled = y[:num_samples] + dataset.data = x_sampled + dataset.targets = y_sampled + + return dataset + + def _fetch_emnist_dataset(self, train: bool) -> EMNIST: + """Fetch the EMNIST dataset.""" + if self.transform is None: + transform = Compose([Transpose(), ToTensor()]) + + dataset = EMNIST( + root=DATA_DIRNAME, + split="byclass", + train=train, + download=False, + transform=transform, + target_transform=self.target_transform, + ) + + if self.sample_to_balance: + dataset = self._sample_to_balance(dataset) + + if self.subsample_fraction is not None: + dataset = self._subsample(dataset) + + return dataset + + def _fetch_emnist_data_loaders(self) -> Dict[str, DataLoader]: + """Fetches the EMNIST dataset and return a Dict of PyTorch DataLoaders.""" + data_loaders = {} + + for split in ["train", "val"]: + if split in self.splits: + + if split == "train": + train = True + else: + train = False + + dataset = self._fetch_emnist_dataset(train) + + data_loader = DataLoader( + dataset=dataset, + batch_size=self.batch_size, + shuffle=self.shuffle, + num_workers=self.num_workers, + pin_memory=self.cuda, + ) + + data_loaders[split] = data_loader + + return data_loaders diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index 736af7b..0cc531a 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -10,7 +10,6 @@ import torch from torch import nn from torchsummary import summary -from text_recognizer.dataset.data_loader import fetch_data_loader WEIGHT_DIRNAME = Path(__file__).parents[1].resolve() / "weights" @@ -22,6 +21,7 @@ class Model(ABC): self, network_fn: Callable, network_args: Dict, + data_loader: Optional[Callable] = None, data_loader_args: Optional[Dict] = None, metrics: Optional[Dict] = None, criterion: Optional[Callable] = None, @@ -32,12 +32,13 @@ class Model(ABC): lr_scheduler_args: Optional[Dict] = None, device: Optional[str] = None, ) -> None: - """Base class, to be inherited by predictors for specific type of data. + """Base class, to be inherited by model for specific type of data. Args: network_fn (Callable): The PyTorch network. network_args (Dict): Arguments for the network. - data_loader_args (Optional[Dict]): Arguments for the data loader. + data_loader (Optional[Callable]): A function that fetches train and val DataLoader. + data_loader_args (Optional[Dict]): Arguments for the DataLoader. metrics (Optional[Dict]): Metrics to evaluate the performance with. Defaults to None. criterion (Optional[Callable]): The criterion to evaulate the preformance of the network. Defaults to None. @@ -53,8 +54,8 @@ class Model(ABC): # Fetch data loaders. if data_loader_args is not None: - self._data_loaders = fetch_data_loader(**data_loader_args) - dataset_name = self._data_loaders.items()[0].dataset.__name__ + self._data_loaders = data_loader(**data_loader_args) + dataset_name = self._data_loaders.__name__ else: dataset_name = "" self._data_loaders = None @@ -210,7 +211,6 @@ class Model(ABC): logger.debug( f"Found a new best {val_metric}. Saving best checkpoint and weights." ) - self.save_weights() shutil.copyfile(filepath, str(path / "best.pt")) def load_weights(self) -> None: diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 1570344..fd69bf2 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -32,17 +32,21 @@ class CharacterModel(Model): super().__init__( network_fn, - data_loader_args, network_args, + data_loader_args, metrics, criterion, + criterion_args, optimizer, + optimizer_args, + lr_scheduler, + lr_scheduler_args, device, ) self.emnist_mapping = self.mapping() self.eval() - def mapping(self) -> Dict: + def mapping(self) -> Dict[int, str]: """Mapping between integers and classes.""" mapping = load_emnist_mapping() return mapping diff --git a/src/text_recognizer/models/metrics.py b/src/text_recognizer/models/metrics.py new file mode 100644 index 0000000..e2a30a9 --- /dev/null +++ b/src/text_recognizer/models/metrics.py @@ -0,0 +1,19 @@ +"""Utility functions for models.""" + +import torch + + +def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: + """Computes the accuracy. + + Args: + outputs (torch.Tensor): The output from the network. + labels (torch.Tensor): Ground truth labels. + + Returns: + float: The accuracy for the batch. + + """ + _, predicted = torch.max(outputs.data, dim=1) + acc = (predicted == labels).sum().item() / labels.shape[0] + return acc diff --git a/src/text_recognizer/models/util.py b/src/text_recognizer/models/util.py deleted file mode 100644 index 905fe7b..0000000 --- a/src/text_recognizer/models/util.py +++ /dev/null @@ -1,19 +0,0 @@ -"""Utility functions for models.""" - -import torch - - -def accuracy(outputs: torch.Tensor, labels: torch.Tensro) -> float: - """Short summary. - - Args: - outputs (torch.Tensor): The output from the network. - labels (torch.Tensor): Ground truth labels. - - Returns: - float: The accuracy for the batch. - - """ - _, predicted = torch.max(outputs.data, dim=1) - acc = (predicted == labels).sum().item() / labels.shape[0] - return acc diff --git a/src/training/gpu_manager.py b/src/training/gpu_manager.py new file mode 100644 index 0000000..ce1b3dd --- /dev/null +++ b/src/training/gpu_manager.py @@ -0,0 +1,62 @@ +"""GPUManager class.""" +import os +import time +from typing import Optional + +import gpustat +from loguru import logger +import numpy as np +from redlock import Redlock + + +GPU_LOCK_TIMEOUT = 5000 # ms + + +class GPUManager: + """Class for allocating GPUs.""" + + def __init__(self, verbose: bool = False) -> None: + """Initializes Redlock manager.""" + self.lock_manager = Redlock([{"host": "localhost", "port": 6379, "db": 0}]) + self.verbose = verbose + + def get_free_gpu(self) -> int: + """Gets a free GPU. + + If some GPUs are available, try reserving one by checking out an exclusive redis lock. + If none available or can not get lock, sleep and check again. + + Returns: + int: The gpu index. + + """ + while True: + gpu_index = self._get_free_gpu() + if gpu_index is not None: + return gpu_index + + if self.verbose: + logger.debug(f"pid {os.getpid()} sleeping") + time.sleep(GPU_LOCK_TIMEOUT / 1000) + + def _get_free_gpu(self) -> Optional[int]: + """Fetches an available GPU index.""" + try: + available_gpu_indices = [ + gpu.index + for gpu in gpustat.GPUStatCollection.new_query() + if gpu.memory_used < 0.5 * gpu.memory_total + ] + except Exception as e: + logger.debug(f"Got the following exception: {e}") + return None + + if available_gpu_indices: + gpu_index = np.random.choice(available_gpu_indices) + if self.verbose: + logger.debug(f"pid {os.getpid()} picking gpu {gpu_index}") + if self.lock_manager.lock(f"gpu_{gpu_index}", GPU_LOCK_TIMEOUT): + return int(gpu_index) + if self.verbose: + logger.debug(f"pid {os.getpid()} could not get lock.") + return None diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py new file mode 100644 index 0000000..1ab8f00 --- /dev/null +++ b/src/training/prepare_experiments.py @@ -0,0 +1,35 @@ +"""Run a experiment from a config file.""" +import json + +import click +from loguru import logger +import yaml + + +def run_experiment(experiment_filename: str) -> None: + """Run experiment from file.""" + with open(experiment_filename) as f: + experiments_config = yaml.safe_load(f) + num_experiments = len(experiments_config["experiments"]) + for index in range(num_experiments): + experiment_config = experiments_config["experiments"][index] + experiment_config["experiment_group"] = experiments_config["experiment_group"] + print( + f"python training/run_experiment.py --gpu=-1 '{json.dumps(experiment_config)}'" + ) + + +@click.command() +@click.option( + "--experiments_filename", + required=True, + type=str, + help="Filename of Yaml file of experiments to run.", +) +def main(experiment_filename: str) -> None: + """Parse command-line arguments and run experiments from provided file.""" + run_experiment(experiment_filename) + + +if __name__ == "__main__": + main() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 8033f47..8296e59 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -1 +1,75 @@ """Script to run experiments.""" +import importlib +import os +from typing import Dict + +import click +import torch +from training.train import Trainer + + +def run_experiment( + experiment_config: Dict, save_weights: bool, gpu_index: int, use_wandb: bool = False +) -> None: + """Short summary.""" + # Import the data loader module and arguments. + datasets_module = importlib.import_module("text_recognizer.datasets") + data_loader_ = getattr(datasets_module, experiment_config["dataloader"]) + data_loader_args = experiment_config.get("data_loader_args", {}) + + # Import the model module and model arguments. + models_module = importlib.import_module("text_recognizer.models") + model_class_ = getattr(models_module, experiment_config["model"]) + + # Import metric. + metric_fn_ = getattr(models_module, experiment_config["metric"]) + + # Import network module and arguments. + network_module = importlib.import_module("text_recognizer.networks") + network_fn_ = getattr(network_module, experiment_config["network"]) + network_args = experiment_config.get("network_args", {}) + + # Criterion + criterion_ = getattr(torch.nn, experiment_config["criterion"]) + criterion_args = experiment_config.get("criterion_args", {}) + + # Optimizer + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]) + optimizer_args = experiment_config.get("optimizer_args", {}) + + # Learning rate scheduler + lr_scheduler_ = None + lr_scheduler_args = None + if experiment_config["lr_scheduler"] is not None: + lr_scheduler_ = getattr( + torch.optim.lr_scheduler, experiment_config["lr_scheduler"] + ) + lr_scheduler_args = experiment_config.get("lr_scheduler_args", {}) + + # Device + # TODO fix gpu manager + device = None + + model = model_class_( + network_fn=network_fn_, + network_args=network_args, + data_loader=data_loader_, + data_loader_args=data_loader_args, + metrics=metric_fn_, + criterion=criterion_, + criterion_args=criterion_args, + optimizer=optimizer_, + optimizer_args=optimizer_args, + lr_scheduler=lr_scheduler_, + lr_scheduler_args=lr_scheduler_args, + device=device, + ) + + # TODO: Fix checkpoint path and wandb + trainer = Trainer( + model=model, + epochs=experiment_config["epochs"], + val_metric=experiment_config["metric"], + ) + + trainer.fit() diff --git a/src/training/train.py b/src/training/train.py index 783de02..4a452b6 100644 --- a/src/training/train.py +++ b/src/training/train.py @@ -9,6 +9,8 @@ import numpy as np import torch from tqdm import tqdm, trange from training.util import RunningAverage +import wandb + torch.backends.cudnn.benchmark = True np.random.seed(4711) @@ -30,6 +32,7 @@ class Trainer: epochs: int, val_metric: str = "accuracy", checkpoint_path: Optional[Path] = None, + use_wandb: Optional[bool] = False, ) -> None: """Initialization of the Trainer. @@ -38,6 +41,7 @@ class Trainer: epochs (int): Number of epochs to train. val_metric (str): The validation metric to evaluate the model on. Defaults to "accuracy". checkpoint_path (Optional[Path]): The path to a previously trained model. Defaults to None. + use_wandb (Optional[bool]): Sync training to wandb. """ self.model = model @@ -48,13 +52,16 @@ class Trainer: if self.checkpoint_path is not None: self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + if use_wandb: + # TODO implement wandb logging. + pass + self.val_metric = val_metric self.best_val_metric = 0.0 logger.add(self.model.name + "_{time}.log") def train(self) -> None: """Training loop.""" - # TODO add summary # Set model to traning mode. self.model.train() @@ -93,10 +100,6 @@ class Trainer: # Perform updates using calculated gradients. self.model.optimizer.step() - # Update the learning rate scheduler. - if self.model.lr_scheduler is not None: - self.model.lr_scheduler.step() - # Compute metrics. loss_avg.update(loss.item()) output = output.data.cpu() @@ -174,8 +177,8 @@ class Trainer: return metrics_mean - def run(self) -> None: - """Training and evaluation loop.""" + def fit(self) -> None: + """Runs the training and evaluation loop.""" # Create new experiment. EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment = datetime.now().strftime("%m%d_%H%M%S") -- cgit v1.2.3-70-g09d2