diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-20 22:18:35 +0200 |
commit | 1f459ba19422593de325983040e176f97cf4ffc0 (patch) | |
tree | 89fef442d5dbe0c83253e9566d1762f0704f64e2 /src/training | |
parent | 95cbdf5bc1cc9639febda23c28d8f464c998b214 (diff) |
A lot of stuff working :D. ResNet implemented!
Diffstat (limited to 'src/training')
-rw-r--r-- | src/training/experiments/sample_experiment.yml | 37 | ||||
-rw-r--r-- | src/training/prepare_experiments.py | 6 | ||||
-rw-r--r-- | src/training/run_experiment.py | 19 | ||||
-rw-r--r-- | src/training/trainer/__init__.py | 2 | ||||
-rw-r--r-- | src/training/trainer/callbacks/__init__.py (renamed from src/training/callbacks/__init__.py) | 2 | ||||
-rw-r--r-- | src/training/trainer/callbacks/base.py (renamed from src/training/callbacks/base.py) | 50 | ||||
-rw-r--r-- | src/training/trainer/callbacks/early_stopping.py (renamed from src/training/callbacks/early_stopping.py) | 5 | ||||
-rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py (renamed from src/training/callbacks/lr_schedulers.py) | 12 | ||||
-rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 61 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py (renamed from src/training/callbacks/wandb_callbacks.py) | 8 | ||||
-rw-r--r-- | src/training/trainer/population_based_training/__init__.py (renamed from src/training/population_based_training/__init__.py) | 0 | ||||
-rw-r--r-- | src/training/trainer/population_based_training/population_based_training.py (renamed from src/training/population_based_training/population_based_training.py) | 0 | ||||
-rw-r--r-- | src/training/trainer/train.py (renamed from src/training/train.py) | 87 | ||||
-rw-r--r-- | src/training/trainer/util.py (renamed from src/training/util.py) | 0 |
14 files changed, 173 insertions, 116 deletions
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 355305c..bae02ac 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -9,25 +9,32 @@ experiments: seed: 4711 data_loader_args: splits: [train, val] - batch_size: 256 shuffle: true num_workers: 8 cuda: true model: CharacterModel metrics: [accuracy] - network: MLP + # network: MLP + # network_args: + # input_size: 784 + # hidden_size: 512 + # output_size: 80 + # num_layers: 3 + # dropout_rate: 0 + # activation_fn: SELU + network: ResidualNetwork network_args: - input_size: 784 - output_size: 62 - num_layers: 3 - activation_fn: GELU + in_channels: 1 + num_classes: 80 + depths: [1, 1] + block_sizes: [128, 256] # network: LeNet # network_args: # output_size: 62 # activation_fn: GELU train_args: batch_size: 256 - epochs: 16 + epochs: 32 criterion: CrossEntropyLoss criterion_args: weight: null @@ -43,20 +50,24 @@ experiments: # centered: false optimizer: AdamW optimizer_args: - lr: 1.e-2 + lr: 1.e-03 betas: [0.9, 0.999] eps: 1.e-08 - weight_decay: 0 + # weight_decay: 5.e-4 amsgrad: false # lr_scheduler: null lr_scheduler: OneCycleLR lr_scheduler_args: - max_lr: 1.e-3 - epochs: 16 - callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] + max_lr: 1.e-03 + epochs: 32 + anneal_strategy: linear + callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR] callback_args: Checkpoint: monitor: val_accuracy + ProgressBar: + epochs: 32 + log_batch_frequency: 100 EarlyStopping: monitor: val_loss min_delta: 0.0 @@ -68,5 +79,5 @@ experiments: num_examples: 4 OneCycleLR: null - verbosity: 2 # 0, 1, 2 + verbosity: 1 # 0, 1, 2 resume_experiment: null diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py index 97c0304..4c3f9ba 100644 --- a/src/training/prepare_experiments.py +++ b/src/training/prepare_experiments.py @@ -7,11 +7,11 @@ from loguru import logger import yaml -# flake8: noqa: S404,S607,S603 def run_experiments(experiments_filename: str) -> None: """Run experiment from file.""" with open(experiments_filename) as f: experiments_config = yaml.safe_load(f) + num_experiments = len(experiments_config["experiments"]) for index in range(num_experiments): experiment_config = experiments_config["experiments"][index] @@ -27,10 +27,10 @@ def run_experiments(experiments_filename: str) -> None: type=str, help="Filename of Yaml file of experiments to run.", ) -def main(experiments_filename: str) -> None: +def run_cli(experiments_filename: str) -> None: """Parse command-line arguments and run experiments from provided file.""" run_experiments(experiments_filename) if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index d278dc2..8c063ff 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -6,18 +6,20 @@ import json import os from pathlib import Path import re -from typing import Callable, Dict, Tuple +from typing import Callable, Dict, Tuple, Type import click from loguru import logger import torch from tqdm import tqdm -from training.callbacks import CallbackList from training.gpu_manager import GPUManager -from training.train import Trainer +from training.trainer.callbacks import CallbackList +from training.trainer.train import Trainer import wandb import yaml +from text_recognizer.models import Model + EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" @@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int: return 10 -def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path: +def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path: """Create new experiment.""" EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True) experiment_dir = EXPERIMENTS_DIRNAME / model.__name__ @@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] """Loads all modules and arguments.""" # Import the data loader arguments. data_loader_args = experiment_config.get("data_loader_args", {}) + train_args = experiment_config.get("train_args", {}) + data_loader_args["batch_size"] = train_args["batch_size"] data_loader_args["dataset"] = experiment_config["dataset"] data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {}) @@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] optimizer_args = experiment_config.get("optimizer_args", {}) # Callbacks - callback_modules = importlib.import_module("training.callbacks") + callback_modules = importlib.import_module("training.trainer.callbacks") callbacks = [ getattr(callback_modules, callback)( **check_args(experiment_config["callback_args"][callback]) @@ -208,6 +212,7 @@ def run_experiment( with open(str(config_path), "w") as f: yaml.dump(experiment_config, f) + # Train the model. trainer = Trainer( model=model, model_dir=model_dir, @@ -247,7 +252,7 @@ def run_experiment( @click.option( "--nowandb", is_flag=False, help="If true, do not use wandb for this run." ) -def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: +def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: """Run experiment.""" if gpu < 0: gpu_manager = GPUManager(True) @@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None: if __name__ == "__main__": - main() + run_cli() diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py new file mode 100644 index 0000000..de41bfb --- /dev/null +++ b/src/training/trainer/__init__.py @@ -0,0 +1,2 @@ +"""Trainer modules.""" +from .train import Trainer diff --git a/src/training/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index fbcc285..5942276 100644 --- a/src/training/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -2,6 +2,7 @@ from .base import Callback, CallbackList, Checkpoint from .early_stopping import EarlyStopping from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR +from .progress_bar import ProgressBar from .wandb_callbacks import WandbCallback, WandbImageLogger __all__ = [ @@ -14,6 +15,7 @@ __all__ = [ "CyclicLR", "MultiStepLR", "OneCycleLR", + "ProgressBar", "ReduceLROnPlateau", "StepLR", ] diff --git a/src/training/callbacks/base.py b/src/training/trainer/callbacks/base.py index e0d91e6..8df94f3 100644 --- a/src/training/callbacks/base.py +++ b/src/training/trainer/callbacks/base.py @@ -1,7 +1,7 @@ """Metaclass for callback functions.""" from enum import Enum -from typing import Callable, Dict, List, Type, Union +from typing import Callable, Dict, List, Optional, Type, Union from loguru import logger import numpy as np @@ -36,27 +36,29 @@ class Callback: """Called when fit ends.""" pass - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch. Only used in training mode.""" pass - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch. Only used in training mode.""" pass - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" pass - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" pass - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: """Called at the beginning of an epoch.""" pass - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" pass @@ -102,18 +104,18 @@ class CallbackList: for callback in self._callbacks: callback.on_fit_end() - def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" for callback in self._callbacks: callback.on_epoch_begin(epoch, logs) - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" for callback in self._callbacks: callback.on_epoch_end(epoch, logs) def _call_batch_hook( - self, mode: str, hook: str, batch: int, logs: Dict = {} + self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None ) -> None: """Helper function for all batch_{begin | end} methods.""" if hook == "begin": @@ -123,39 +125,45 @@ class CallbackList: else: raise ValueError(f"Unrecognized hook {hook}.") - def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: + def _call_batch_begin_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: """Helper function for all `on_*_batch_begin` methods.""" hook_name = f"on_{mode}_batch_begin" self._call_batch_hook_helper(hook_name, batch, logs) - def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None: + def _call_batch_end_hook( + self, mode: str, batch: int, logs: Optional[Dict] = None + ) -> None: """Helper function for all `on_*_batch_end` methods.""" hook_name = f"on_{mode}_batch_end" self._call_batch_hook_helper(hook_name, batch, logs) def _call_batch_hook_helper( - self, hook_name: str, batch: int, logs: Dict = {} + self, hook_name: str, batch: int, logs: Optional[Dict] = None ) -> None: """Helper function for `on_*_batch_begin` methods.""" for callback in self._callbacks: hook = getattr(callback, hook_name) hook(batch, logs) - def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch) + self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs) - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.TRAIN, "end", batch) + self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs) - def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_begin( + self, batch: int, logs: Optional[Dict] = None + ) -> None: """Called at the beginning of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch) + self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs) - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Called at the end of an epoch.""" - self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch) + self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs) def __iter__(self) -> iter: """Iter function for callback list.""" diff --git a/src/training/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py index c9b7907..02b431f 100644 --- a/src/training/callbacks/early_stopping.py +++ b/src/training/trainer/callbacks/early_stopping.py @@ -4,7 +4,8 @@ from typing import Dict, Union from loguru import logger import numpy as np import torch -from training.callbacks import Callback +from torch import Tensor +from training.trainer.callbacks import Callback class EarlyStopping(Callback): @@ -95,7 +96,7 @@ class EarlyStopping(Callback): f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping." ) - def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]: + def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]: """Extracts the monitor value.""" monitor_value = logs.get(self.monitor) if monitor_value is None: diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index 00c7e9b..ba2226a 100644 --- a/src/training/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -1,7 +1,7 @@ """Callbacks for learning rate schedulers.""" from typing import Callable, Dict, List, Optional, Type -from training.callbacks import Callback +from training.trainer.callbacks import Callback from text_recognizer.models import Model @@ -19,7 +19,7 @@ class StepLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" self.lr_scheduler.step() @@ -37,7 +37,7 @@ class MultiStepLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" self.lr_scheduler.step() @@ -55,7 +55,7 @@ class ReduceLROnPlateau(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None: + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" val_loss = logs["val_loss"] self.lr_scheduler.step(val_loss) @@ -74,7 +74,7 @@ class CyclicLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" self.lr_scheduler.step() @@ -92,6 +92,6 @@ class OneCycleLR(Callback): self.model = model self.lr_scheduler = self.model.lr_scheduler - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" self.lr_scheduler.step() diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py new file mode 100644 index 0000000..1970747 --- /dev/null +++ b/src/training/trainer/callbacks/progress_bar.py @@ -0,0 +1,61 @@ +"""Progress bar callback for the training loop.""" +from typing import Dict, Optional + +from tqdm import tqdm +from training.trainer.callbacks import Callback + + +class ProgressBar(Callback): + """A TQDM progress bar for the training loop.""" + + def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: + """Initializes the tqdm callback.""" + self.epochs = epochs + self.log_batch_frequency = log_batch_frequency + self.progress_bar = None + self.val_metrics = {} + + def _configure_progress_bar(self) -> None: + """Configures the tqdm progress bar with custom bar format.""" + self.progress_bar = tqdm( + total=len(self.model.data_loaders["train"]), + leave=True, + unit="step", + mininterval=self.log_batch_frequency, + bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}", + ) + + def _key_abbreviations(self, logs: Dict) -> Dict: + """Changes the length of keys, so that the progress bar fits better.""" + + def rename(key: str) -> str: + """Renames accuracy to acc.""" + return key.replace("accuracy", "acc") + + return {rename(key): value for key, value in logs.items()} + + def on_fit_begin(self) -> None: + """Creates a tqdm progress bar.""" + self._configure_progress_bar() + + def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None: + """Updates the description with the current epoch.""" + self.progress_bar.reset() + self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}") + + def on_epoch_end(self, epoch: int, logs: Dict) -> None: + """At the end of each epoch, the validation metrics are updated to the progress bar.""" + self.val_metrics = logs + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_train_batch_end(self, batch: int, logs: Dict) -> None: + """Updates the progress bar for each training step.""" + if self.val_metrics: + logs.update(self.val_metrics) + self.progress_bar.set_postfix(**self._key_abbreviations(logs)) + self.progress_bar.update() + + def on_fit_end(self) -> None: + """Closes the tqdm progress bar.""" + self.progress_bar.close() diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 6ada6df..e44c745 100644 --- a/src/training/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -1,9 +1,9 @@ -"""Callbacks using wandb.""" +"""Callback for W&B.""" from typing import Callable, Dict, List, Optional, Type import numpy as np from torchvision.transforms import Compose, ToTensor -from training.callbacks import Callback +from training.trainer.callbacks import Callback import wandb from text_recognizer.datasets import Transpose @@ -28,12 +28,12 @@ class WandbCallback(Callback): if self.log_batch_frequency and batch % self.log_batch_frequency == 0: wandb.log(logs, commit=True) - def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs training metrics.""" if logs is not None: self._on_batch_end(batch, logs) - def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None: + def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs validation metrics.""" if logs is not None: self._on_batch_end(batch, logs) diff --git a/src/training/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py index 868d739..868d739 100644 --- a/src/training/population_based_training/__init__.py +++ b/src/training/trainer/population_based_training/__init__.py diff --git a/src/training/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py index 868d739..868d739 100644 --- a/src/training/population_based_training/population_based_training.py +++ b/src/training/trainer/population_based_training/population_based_training.py diff --git a/src/training/train.py b/src/training/trainer/train.py index aaa0430..a75ae8f 100644 --- a/src/training/train.py +++ b/src/training/trainer/train.py @@ -7,9 +7,9 @@ from typing import Dict, List, Optional, Tuple, Type from loguru import logger import numpy as np import torch -from tqdm import tqdm, trange -from training.callbacks import Callback, CallbackList -from training.util import RunningAverage +from torch import Tensor +from training.trainer.callbacks import Callback, CallbackList +from training.trainer.util import RunningAverage import wandb from text_recognizer.models import Model @@ -46,11 +46,11 @@ class Trainer: self.model_dir = model_dir self.checkpoint_path = checkpoint_path self.start_epoch = 1 - self.epochs = train_args["epochs"] + self.start_epoch + self.epochs = train_args["epochs"] self.callbacks = callbacks if self.checkpoint_path is not None: - self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1 + self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) # Parse the name of the experiment. experiment_dir = str(self.model_dir.parents[1]).split("/") @@ -59,7 +59,7 @@ class Trainer: def training_step( self, batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], + samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], ) -> Dict: """Performs the training step.""" @@ -108,27 +108,16 @@ class Trainer: data_loader = self.model.data_loaders["train"] - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - for batch, samples in enumerate(data_loader): - self.callbacks.on_train_batch_begin(batch) - - metrics = self.training_step(batch, samples, loss_avg) - - self.callbacks.on_train_batch_end(batch, logs=metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() + for batch, samples in enumerate(data_loader): + self.callbacks.on_train_batch_begin(batch) + metrics = self.training_step(batch, samples, loss_avg) + self.callbacks.on_train_batch_end(batch, logs=metrics) + @torch.no_grad() def validation_step( self, batch: int, - samples: Tuple[torch.Tensor, torch.Tensor], + samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], ) -> Dict: """Performs the validation step.""" @@ -158,6 +147,12 @@ class Trainer: return metrics + def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None: + log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ") + logger.debug( + log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) + ) + def validate(self, epoch: Optional[int] = None) -> Dict: """Runs the validation loop for one epoch.""" # Set model to eval mode. @@ -172,41 +167,18 @@ class Trainer: # Summary for the current eval loop. summary = [] - with tqdm( - total=len(data_loader), - leave=False, - unit="step", - bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}", - ) as t: - with torch.no_grad(): - for batch, samples in enumerate(data_loader): - self.callbacks.on_validation_batch_begin(batch) - - metrics = self.validation_step(batch, samples, loss_avg) - - self.callbacks.on_validation_batch_end(batch, logs=metrics) - - summary.append(metrics) - - # Update Tqdm progress bar. - t.set_postfix(**metrics) - t.update() + for batch, samples in enumerate(data_loader): + self.callbacks.on_validation_batch_begin(batch) + metrics = self.validation_step(batch, samples, loss_avg) + self.callbacks.on_validation_batch_end(batch, logs=metrics) + summary.append(metrics) # Compute mean of all metrics. metrics_mean = { "val_" + metric: np.mean([x[metric] for x in summary]) for metric in summary[0] } - if epoch: - logger.debug( - f"Validation metrics at epoch {epoch} - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) - else: - logger.debug( - "Validation metrics - " - + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items()) - ) + self._log_val_metric(metrics_mean, epoch) return metrics_mean @@ -214,19 +186,14 @@ class Trainer: """Runs the training and evaluation loop.""" logger.debug(f"Running an experiment called {self.experiment_name}.") + + # Set start time. t_start = time.time() self.callbacks.on_fit_begin() - # TODO: fix progress bar as callback. # Run the training loop. - for epoch in trange( - self.start_epoch, - self.epochs, - leave=False, - bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}", - desc="Epoch", - ): + for epoch in range(self.start_epoch, self.epochs + 1): self.callbacks.on_epoch_begin(epoch) # Perform one training pass over the training set. diff --git a/src/training/util.py b/src/training/trainer/util.py index 132b2dc..132b2dc 100644 --- a/src/training/util.py +++ b/src/training/trainer/util.py |