Diffstat (limited to 'src/training')
 src/training/experiments/sample_experiment.yml                              | 37
 src/training/prepare_experiments.py                                         |  6
 src/training/run_experiment.py                                              | 19
 src/training/trainer/__init__.py                                            |  2
 src/training/trainer/callbacks/__init__.py                                  |  2  (renamed from src/training/callbacks/__init__.py)
 src/training/trainer/callbacks/base.py                                      | 50  (renamed from src/training/callbacks/base.py)
 src/training/trainer/callbacks/early_stopping.py                            |  5  (renamed from src/training/callbacks/early_stopping.py)
 src/training/trainer/callbacks/lr_schedulers.py                             | 12  (renamed from src/training/callbacks/lr_schedulers.py)
 src/training/trainer/callbacks/progress_bar.py                              | 61
 src/training/trainer/callbacks/wandb_callbacks.py                           |  8  (renamed from src/training/callbacks/wandb_callbacks.py)
 src/training/trainer/population_based_training/__init__.py                  |  0  (renamed from src/training/population_based_training/__init__.py)
 src/training/trainer/population_based_training/population_based_training.py |  0  (renamed from src/training/population_based_training/population_based_training.py)
 src/training/trainer/train.py                                               | 87  (renamed from src/training/train.py)
 src/training/trainer/util.py                                                |  0  (renamed from src/training/util.py)
 14 files changed, 173 insertions(+), 116 deletions(-)
diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml
index 355305c..bae02ac 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/src/training/experiments/sample_experiment.yml
@@ -9,25 +9,32 @@ experiments:
seed: 4711
data_loader_args:
splits: [train, val]
- batch_size: 256
shuffle: true
num_workers: 8
cuda: true
model: CharacterModel
metrics: [accuracy]
- network: MLP
+ # network: MLP
+ # network_args:
+ # input_size: 784
+ # hidden_size: 512
+ # output_size: 80
+ # num_layers: 3
+ # dropout_rate: 0
+ # activation_fn: SELU
+ network: ResidualNetwork
network_args:
- input_size: 784
- output_size: 62
- num_layers: 3
- activation_fn: GELU
+ in_channels: 1
+ num_classes: 80
+ depths: [1, 1]
+ block_sizes: [128, 256]
# network: LeNet
# network_args:
# output_size: 62
# activation_fn: GELU
train_args:
batch_size: 256
- epochs: 16
+ epochs: 32
criterion: CrossEntropyLoss
criterion_args:
weight: null
@@ -43,20 +50,24 @@ experiments:
# centered: false
optimizer: AdamW
optimizer_args:
- lr: 1.e-2
+ lr: 1.e-03
betas: [0.9, 0.999]
eps: 1.e-08
- weight_decay: 0
+ # weight_decay: 5.e-4
amsgrad: false
# lr_scheduler: null
lr_scheduler: OneCycleLR
lr_scheduler_args:
- max_lr: 1.e-3
- epochs: 16
- callbacks: [Checkpoint, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
+ max_lr: 1.e-03
+ epochs: 32
+ anneal_strategy: linear
+ callbacks: [Checkpoint, ProgressBar, EarlyStopping, WandbCallback, WandbImageLogger, OneCycleLR]
callback_args:
Checkpoint:
monitor: val_accuracy
+ ProgressBar:
+ epochs: 32
+ log_batch_frequency: 100
EarlyStopping:
monitor: val_loss
min_delta: 0.0
@@ -68,5 +79,5 @@ experiments:
num_examples: 4
OneCycleLR:
null
- verbosity: 2 # 0, 1, 2
+ verbosity: 1 # 0, 1, 2
resume_experiment: null
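
Note on the config above: it is read by prepare_experiments.py exactly as sketched below. The file path and the keys (experiments, network, train_args, callbacks) come from this diff; the index 0 is only an example.

import yaml

# Load the experiments file and pick the first entry, mirroring the loop in
# prepare_experiments.py further down in this diff.
with open("src/training/experiments/sample_experiment.yml") as f:
    experiments_config = yaml.safe_load(f)

experiment_config = experiments_config["experiments"][0]
print(experiment_config["network"])                   # ResidualNetwork
print(experiment_config["train_args"]["batch_size"])  # 256 (batch_size now lives only under train_args)
print(experiment_config["callbacks"])                 # [Checkpoint, ProgressBar, EarlyStopping, ...]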
diff --git a/src/training/prepare_experiments.py b/src/training/prepare_experiments.py
index 97c0304..4c3f9ba 100644
--- a/src/training/prepare_experiments.py
+++ b/src/training/prepare_experiments.py
@@ -7,11 +7,11 @@ from loguru import logger
import yaml
-# flake8: noqa: S404,S607,S603
def run_experiments(experiments_filename: str) -> None:
"""Run experiment from file."""
with open(experiments_filename) as f:
experiments_config = yaml.safe_load(f)
+
num_experiments = len(experiments_config["experiments"])
for index in range(num_experiments):
experiment_config = experiments_config["experiments"][index]
@@ -27,10 +27,10 @@ def run_experiments(experiments_filename: str) -> None:
type=str,
help="Filename of Yaml file of experiments to run.",
)
-def main(experiments_filename: str) -> None:
+def run_cli(experiments_filename: str) -> None:
"""Parse command-line arguments and run experiments from provided file."""
run_experiments(experiments_filename)
if __name__ == "__main__":
- main()
+ run_cli()
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index d278dc2..8c063ff 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -6,18 +6,20 @@ import json
import os
from pathlib import Path
import re
-from typing import Callable, Dict, Tuple
+from typing import Callable, Dict, Tuple, Type
import click
from loguru import logger
import torch
from tqdm import tqdm
-from training.callbacks import CallbackList
from training.gpu_manager import GPUManager
-from training.train import Trainer
+from training.trainer.callbacks import CallbackList
+from training.trainer.train import Trainer
import wandb
import yaml
+from text_recognizer.models import Model
+
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
@@ -35,7 +37,7 @@ def get_level(experiment_config: Dict) -> int:
return 10
-def create_experiment_dir(model: Callable, experiment_config: Dict) -> Path:
+def create_experiment_dir(model: Type[Model], experiment_config: Dict) -> Path:
"""Create new experiment."""
EXPERIMENTS_DIRNAME.mkdir(parents=True, exist_ok=True)
experiment_dir = EXPERIMENTS_DIRNAME / model.__name__
@@ -67,6 +69,8 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
"""Loads all modules and arguments."""
# Import the data loader arguments.
data_loader_args = experiment_config.get("data_loader_args", {})
+ train_args = experiment_config.get("train_args", {})
+ data_loader_args["batch_size"] = train_args["batch_size"]
data_loader_args["dataset"] = experiment_config["dataset"]
data_loader_args["dataset_args"] = experiment_config.get("dataset_args", {})
@@ -94,7 +98,7 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]
optimizer_args = experiment_config.get("optimizer_args", {})
# Callbacks
- callback_modules = importlib.import_module("training.callbacks")
+ callback_modules = importlib.import_module("training.trainer.callbacks")
callbacks = [
getattr(callback_modules, callback)(
**check_args(experiment_config["callback_args"][callback])
@@ -208,6 +212,7 @@ def run_experiment(
with open(str(config_path), "w") as f:
yaml.dump(experiment_config, f)
+ # Train the model.
trainer = Trainer(
model=model,
model_dir=model_dir,
@@ -247,7 +252,7 @@ def run_experiment(
@click.option(
"--nowandb", is_flag=False, help="If true, do not use wandb for this run."
)
-def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
+def run_cli(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
"""Run experiment."""
if gpu < 0:
gpu_manager = GPUManager(True)
@@ -260,4 +265,4 @@ def main(experiment_config: str, gpu: int, save: bool, nowandb: bool) -> None:
if __name__ == "__main__":
- main()
+ run_cli()
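
Note on the load_modules_and_arguments change above: batch_size was removed from data_loader_args in the YAML, so it is now copied over from train_args. A minimal sketch of that propagation (the dict contents here are illustrative, not taken from the repo):

# batch_size is declared once, under train_args, and forwarded to the loader args.
experiment_config = {
    "dataset": "SomeDataset",  # illustrative placeholder
    "data_loader_args": {"splits": ["train", "val"], "shuffle": True, "num_workers": 8},
    "train_args": {"batch_size": 256, "epochs": 32},
}

data_loader_args = experiment_config.get("data_loader_args", {})
train_args = experiment_config.get("train_args", {})
data_loader_args["batch_size"] = train_args["batch_size"]
data_loader_args["dataset"] = experiment_config["dataset"]
assert data_loader_args["batch_size"] == 256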
diff --git a/src/training/trainer/__init__.py b/src/training/trainer/__init__.py
new file mode 100644
index 0000000..de41bfb
--- /dev/null
+++ b/src/training/trainer/__init__.py
@@ -0,0 +1,2 @@
+"""Trainer modules."""
+from .train import Trainer
diff --git a/src/training/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index fbcc285..5942276 100644
--- a/src/training/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -2,6 +2,7 @@
from .base import Callback, CallbackList, Checkpoint
from .early_stopping import EarlyStopping
from .lr_schedulers import CyclicLR, MultiStepLR, OneCycleLR, ReduceLROnPlateau, StepLR
+from .progress_bar import ProgressBar
from .wandb_callbacks import WandbCallback, WandbImageLogger
__all__ = [
@@ -14,6 +15,7 @@ __all__ = [
"CyclicLR",
"MultiStepLR",
"OneCycleLR",
+ "ProgressBar",
"ReduceLROnPlateau",
"StepLR",
]
diff --git a/src/training/callbacks/base.py b/src/training/trainer/callbacks/base.py
index e0d91e6..8df94f3 100644
--- a/src/training/callbacks/base.py
+++ b/src/training/trainer/callbacks/base.py
@@ -1,7 +1,7 @@
"""Metaclass for callback functions."""
from enum import Enum
-from typing import Callable, Dict, List, Type, Union
+from typing import Callable, Dict, List, Optional, Type, Union
from loguru import logger
import numpy as np
@@ -36,27 +36,29 @@ class Callback:
"""Called when fit ends."""
pass
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch. Only used in training mode."""
pass
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch. Only used in training mode."""
pass
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
pass
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
pass
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Called at the beginning of an epoch."""
pass
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
pass
@@ -102,18 +104,18 @@ class CallbackList:
for callback in self._callbacks:
callback.on_fit_end()
- def on_epoch_begin(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
for callback in self._callbacks:
callback.on_epoch_begin(epoch, logs)
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
for callback in self._callbacks:
callback.on_epoch_end(epoch, logs)
def _call_batch_hook(
- self, mode: str, hook: str, batch: int, logs: Dict = {}
+ self, mode: str, hook: str, batch: int, logs: Optional[Dict] = None
) -> None:
"""Helper function for all batch_{begin | end} methods."""
if hook == "begin":
@@ -123,39 +125,45 @@ class CallbackList:
else:
raise ValueError(f"Unrecognized hook {hook}.")
- def _call_batch_begin_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
+ def _call_batch_begin_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Helper function for all `on_*_batch_begin` methods."""
hook_name = f"on_{mode}_batch_begin"
self._call_batch_hook_helper(hook_name, batch, logs)
- def _call_batch_end_hook(self, mode: str, batch: int, logs: Dict = {}) -> None:
+ def _call_batch_end_hook(
+ self, mode: str, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Helper function for all `on_*_batch_end` methods."""
hook_name = f"on_{mode}_batch_end"
self._call_batch_hook_helper(hook_name, batch, logs)
def _call_batch_hook_helper(
- self, hook_name: str, batch: int, logs: Dict = {}
+ self, hook_name: str, batch: int, logs: Optional[Dict] = None
) -> None:
"""Helper function for `on_*_batch_begin` methods."""
for callback in self._callbacks:
hook = getattr(callback, hook_name)
hook(batch, logs)
- def on_train_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_begin(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch)
+ self._call_batch_hook(self.mode_keys.TRAIN, "begin", batch, logs)
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.TRAIN, "end", batch)
+ self._call_batch_hook(self.mode_keys.TRAIN, "end", batch, logs)
- def on_validation_batch_begin(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_begin(
+ self, batch: int, logs: Optional[Dict] = None
+ ) -> None:
"""Called at the beginning of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch)
+ self._call_batch_hook(self.mode_keys.VALIDATION, "begin", batch, logs)
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Called at the end of an epoch."""
- self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch)
+ self._call_batch_hook(self.mode_keys.VALIDATION, "end", batch, logs)
def __iter__(self) -> iter:
"""Iter function for callback list."""
diff --git a/src/training/callbacks/early_stopping.py b/src/training/trainer/callbacks/early_stopping.py
index c9b7907..02b431f 100644
--- a/src/training/callbacks/early_stopping.py
+++ b/src/training/trainer/callbacks/early_stopping.py
@@ -4,7 +4,8 @@ from typing import Dict, Union
from loguru import logger
import numpy as np
import torch
-from training.callbacks import Callback
+from torch import Tensor
+from training.trainer.callbacks import Callback
class EarlyStopping(Callback):
@@ -95,7 +96,7 @@ class EarlyStopping(Callback):
f"Stopped training at epoch {self.stopped_epoch + 1} with early stopping."
)
- def get_monitor_value(self, logs: Dict) -> Union[torch.Tensor, None]:
+ def get_monitor_value(self, logs: Dict) -> Union[Tensor, None]:
"""Extracts the monitor value."""
monitor_value = logs.get(self.monitor)
if monitor_value is None:
diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index 00c7e9b..ba2226a 100644
--- a/src/training/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -1,7 +1,7 @@
"""Callbacks for learning rate schedulers."""
from typing import Callable, Dict, List, Optional, Type
-from training.callbacks import Callback
+from training.trainer.callbacks import Callback
from text_recognizer.models import Model
@@ -19,7 +19,7 @@ class StepLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
self.lr_scheduler.step()
@@ -37,7 +37,7 @@ class MultiStepLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
self.lr_scheduler.step()
@@ -55,7 +55,7 @@ class ReduceLROnPlateau(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+ def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
val_loss = logs["val_loss"]
self.lr_scheduler.step(val_loss)
@@ -74,7 +74,7 @@ class CyclicLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
@@ -92,6 +92,6 @@ class OneCycleLR(Callback):
self.model = model
self.lr_scheduler = self.model.lr_scheduler
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
self.lr_scheduler.step()
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
new file mode 100644
index 0000000..1970747
--- /dev/null
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -0,0 +1,61 @@
+"""Progress bar callback for the training loop."""
+from typing import Dict, Optional
+
+from tqdm import tqdm
+from training.trainer.callbacks import Callback
+
+
+class ProgressBar(Callback):
+ """A TQDM progress bar for the training loop."""
+
+ def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
+ """Initializes the tqdm callback."""
+ self.epochs = epochs
+ self.log_batch_frequency = log_batch_frequency
+ self.progress_bar = None
+ self.val_metrics = {}
+
+ def _configure_progress_bar(self) -> None:
+ """Configures the tqdm progress bar with custom bar format."""
+ self.progress_bar = tqdm(
+ total=len(self.model.data_loaders["train"]),
+ leave=True,
+ unit="step",
+ mininterval=self.log_batch_frequency,
+ bar_format="{desc} |{bar:30}| {n_fmt}/{total_fmt} ETA: {remaining} {rate_fmt}{postfix}",
+ )
+
+ def _key_abbreviations(self, logs: Dict) -> Dict:
+ """Changes the length of keys, so that the progress bar fits better."""
+
+ def rename(key: str) -> str:
+ """Renames accuracy to acc."""
+ return key.replace("accuracy", "acc")
+
+ return {rename(key): value for key, value in logs.items()}
+
+ def on_fit_begin(self) -> None:
+ """Creates a tqdm progress bar."""
+ self._configure_progress_bar()
+
+ def on_epoch_begin(self, epoch: int, logs: Optional[Dict]) -> None:
+ """Updates the description with the current epoch."""
+ self.progress_bar.reset()
+ self.progress_bar.set_description(f"Epoch {epoch}/{self.epochs}")
+
+ def on_epoch_end(self, epoch: int, logs: Dict) -> None:
+ """At the end of each epoch, the validation metrics are updated to the progress bar."""
+ self.val_metrics = logs
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_train_batch_end(self, batch: int, logs: Dict) -> None:
+ """Updates the progress bar for each training step."""
+ if self.val_metrics:
+ logs.update(self.val_metrics)
+ self.progress_bar.set_postfix(**self._key_abbreviations(logs))
+ self.progress_bar.update()
+
+ def on_fit_end(self) -> None:
+ """Closes the tqdm progress bar."""
+ self.progress_bar.close()
diff --git a/src/training/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 6ada6df..e44c745 100644
--- a/src/training/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -1,9 +1,9 @@
-"""Callbacks using wandb."""
+"""Callback for W&B."""
from typing import Callable, Dict, List, Optional, Type
import numpy as np
from torchvision.transforms import Compose, ToTensor
-from training.callbacks import Callback
+from training.trainer.callbacks import Callback
import wandb
from text_recognizer.datasets import Transpose
@@ -28,12 +28,12 @@ class WandbCallback(Callback):
if self.log_batch_frequency and batch % self.log_batch_frequency == 0:
wandb.log(logs, commit=True)
- def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs training metrics."""
if logs is not None:
self._on_batch_end(batch, logs)
- def on_validation_batch_end(self, batch: int, logs: Dict = {}) -> None:
+ def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs validation metrics."""
if logs is not None:
self._on_batch_end(batch, logs)
diff --git a/src/training/population_based_training/__init__.py b/src/training/trainer/population_based_training/__init__.py
index 868d739..868d739 100644
--- a/src/training/population_based_training/__init__.py
+++ b/src/training/trainer/population_based_training/__init__.py
diff --git a/src/training/population_based_training/population_based_training.py b/src/training/trainer/population_based_training/population_based_training.py
index 868d739..868d739 100644
--- a/src/training/population_based_training/population_based_training.py
+++ b/src/training/trainer/population_based_training/population_based_training.py
diff --git a/src/training/train.py b/src/training/trainer/train.py
index aaa0430..a75ae8f 100644
--- a/src/training/train.py
+++ b/src/training/trainer/train.py
@@ -7,9 +7,9 @@ from typing import Dict, List, Optional, Tuple, Type
from loguru import logger
import numpy as np
import torch
-from tqdm import tqdm, trange
-from training.callbacks import Callback, CallbackList
-from training.util import RunningAverage
+from torch import Tensor
+from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.util import RunningAverage
import wandb
from text_recognizer.models import Model
@@ -46,11 +46,11 @@ class Trainer:
self.model_dir = model_dir
self.checkpoint_path = checkpoint_path
self.start_epoch = 1
- self.epochs = train_args["epochs"] + self.start_epoch
+ self.epochs = train_args["epochs"]
self.callbacks = callbacks
if self.checkpoint_path is not None:
- self.start_epoch = self.model.load_checkpoint(self.checkpoint_path) + 1
+ self.start_epoch = self.model.load_checkpoint(self.checkpoint_path)
# Parse the name of the experiment.
experiment_dir = str(self.model_dir.parents[1]).split("/")
@@ -59,7 +59,7 @@ class Trainer:
def training_step(
self,
batch: int,
- samples: Tuple[torch.Tensor, torch.Tensor],
+ samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
) -> Dict:
"""Performs the training step."""
@@ -108,27 +108,16 @@ class Trainer:
data_loader = self.model.data_loaders["train"]
- with tqdm(
- total=len(data_loader),
- leave=False,
- unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
- ) as t:
- for batch, samples in enumerate(data_loader):
- self.callbacks.on_train_batch_begin(batch)
-
- metrics = self.training_step(batch, samples, loss_avg)
-
- self.callbacks.on_train_batch_end(batch, logs=metrics)
-
- # Update Tqdm progress bar.
- t.set_postfix(**metrics)
- t.update()
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_train_batch_begin(batch)
+ metrics = self.training_step(batch, samples, loss_avg)
+ self.callbacks.on_train_batch_end(batch, logs=metrics)
+ @torch.no_grad()
def validation_step(
self,
batch: int,
- samples: Tuple[torch.Tensor, torch.Tensor],
+ samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
) -> Dict:
"""Performs the validation step."""
@@ -158,6 +147,12 @@ class Trainer:
return metrics
+ def _log_val_metric(self, metrics_mean: Dict, epoch: Optional[int] = None) -> None:
+ log_str = "Validation metrics " + (f"at epoch {epoch} - " if epoch else " - ")
+ logger.debug(
+ log_str + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
+ )
+
def validate(self, epoch: Optional[int] = None) -> Dict:
"""Runs the validation loop for one epoch."""
# Set model to eval mode.
@@ -172,41 +167,18 @@ class Trainer:
# Summary for the current eval loop.
summary = []
- with tqdm(
- total=len(data_loader),
- leave=False,
- unit="step",
- bar_format="{n_fmt}/{total_fmt} |{bar:30}| {remaining} {rate_inv_fmt}{postfix}",
- ) as t:
- with torch.no_grad():
- for batch, samples in enumerate(data_loader):
- self.callbacks.on_validation_batch_begin(batch)
-
- metrics = self.validation_step(batch, samples, loss_avg)
-
- self.callbacks.on_validation_batch_end(batch, logs=metrics)
-
- summary.append(metrics)
-
- # Update Tqdm progress bar.
- t.set_postfix(**metrics)
- t.update()
+ for batch, samples in enumerate(data_loader):
+ self.callbacks.on_validation_batch_begin(batch)
+ metrics = self.validation_step(batch, samples, loss_avg)
+ self.callbacks.on_validation_batch_end(batch, logs=metrics)
+ summary.append(metrics)
# Compute mean of all metrics.
metrics_mean = {
"val_" + metric: np.mean([x[metric] for x in summary])
for metric in summary[0]
}
- if epoch:
- logger.debug(
- f"Validation metrics at epoch {epoch} - "
- + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
- else:
- logger.debug(
- "Validation metrics - "
- + " - ".join(f"{k}: {v:.4f}" for k, v in metrics_mean.items())
- )
+ self._log_val_metric(metrics_mean, epoch)
return metrics_mean
@@ -214,19 +186,14 @@ class Trainer:
"""Runs the training and evaluation loop."""
logger.debug(f"Running an experiment called {self.experiment_name}.")
+
+ # Set start time.
t_start = time.time()
self.callbacks.on_fit_begin()
- # TODO: fix progress bar as callback.
# Run the training loop.
- for epoch in trange(
- self.start_epoch,
- self.epochs,
- leave=False,
- bar_format="{desc}: {n_fmt}/{total_fmt} |{bar:30}| {remaining}{postfix}",
- desc="Epoch",
- ):
+ for epoch in range(self.start_epoch, self.epochs + 1):
self.callbacks.on_epoch_begin(epoch)
# Perform one training pass over the training set.
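
Note on the epoch bookkeeping above: `self.epochs` is now just `train_args["epochs"]` and the loop is a plain `range(self.start_epoch, self.epochs + 1)`, whereas before `self.epochs` had `start_epoch` folded in. For a fresh run the iteration count is unchanged, and the epoch numbers now match the ProgressBar's "Epoch {epoch}/{self.epochs}" description; how resume behaves depends on what load_checkpoint returns, which this diff does not show.

# train_args["epochs"] = 32, start_epoch = 1 (fresh run)
old_epochs = 32 + 1  # previous bookkeeping: epochs = train_args["epochs"] + start_epoch
new_epochs = 32      # current bookkeeping: epochs = train_args["epochs"]
assert list(range(1, old_epochs)) == list(range(1, new_epochs + 1))  # same 32 epochs, numbered 1..32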
diff --git a/src/training/util.py b/src/training/trainer/util.py
index 132b2dc..132b2dc 100644
--- a/src/training/util.py
+++ b/src/training/trainer/util.py