diff options
Diffstat (limited to 'src/training/trainer')
-rw-r--r-- | src/training/trainer/callbacks/__init__.py | 14 | ||||
-rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 121 | ||||
-rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 1 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 1 | ||||
-rw-r--r-- | src/training/trainer/train.py | 22 |
5 files changed, 39 insertions, 120 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index c81e4bf..e1bd858 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -3,12 +3,7 @@ from .base import Callback, CallbackList from .checkpoint import Checkpoint from .early_stopping import EarlyStopping from .lr_schedulers import ( - CosineAnnealingLR, - CyclicLR, - MultiStepLR, - OneCycleLR, - ReduceLROnPlateau, - StepLR, + LRScheduler, SWA, ) from .progress_bar import ProgressBar @@ -18,15 +13,10 @@ __all__ = [ "Callback", "CallbackList", "Checkpoint", - "CosineAnnealingLR", "EarlyStopping", + "LRScheduler", "WandbCallback", "WandbImageLogger", - "CyclicLR", - "MultiStepLR", - "OneCycleLR", "ProgressBar", - "ReduceLROnPlateau", - "StepLR", "SWA", ] diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index bb41d2d..907e292 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -7,113 +7,27 @@ from training.trainer.callbacks import Callback from text_recognizer.models import Model -class StepLR(Callback): - """Callback for StepLR.""" +class LRScheduler(Callback): + """Generic learning rate scheduler callback.""" def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class MultiStepLR(Callback): - """Callback for MultiStepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class ReduceLROnPlateau(Callback): - """Callback for ReduceLROnPlateau.""" - - def __init__(self) -> None: - """Initializes the callback.""" super().__init__() - self.lr_scheduler = None def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" - val_loss = logs["val_loss"] - self.lr_scheduler.step(val_loss) - - -class CyclicLR(Callback): - """Callback for CyclicLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class OneCycleLR(Callback): - """Callback for OneCycleLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler + if self.interval == "epoch": + self.lr_scheduler.step() def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class CosineAnnealingLR(Callback): - """Callback for Cosine Annealing.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() + if self.interval == "step": + self.lr_scheduler.step() class SWA(Callback): @@ -122,21 +36,32 @@ class SWA(Callback): def __init__(self) -> None: """Initializes the callback.""" super().__init__() + self.lr_scheduler = None + self.interval = None self.swa_scheduler = None + self.swa_start = None + self.current_epoch = 1 def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.swa_start = self.model.swa_start - self.swa_scheduler = self.model.lr_scheduler - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] + self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] + self.swa_start = self.model.swa_scheduler["swa_start"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" if epoch > self.swa_start: self.model.swa_network.update_parameters(self.model.network) self.swa_scheduler.step() - else: + elif self.interval == "epoch": + self.lr_scheduler.step() + self.current_epoch = epoch + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if self.current_epoch < self.swa_start and self.interval == "step": self.lr_scheduler.step() def on_fit_end(self) -> None: diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 7829fa0..6c4305a 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -11,6 +11,7 @@ class ProgressBar(Callback): def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: """Initializes the tqdm callback.""" self.epochs = epochs + print(epochs, type(epochs)) self.log_batch_frequency = log_batch_frequency self.progress_bar = None self.val_metrics = {} diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 6643a44..d2df4d7 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -32,6 +32,7 @@ class WandbCallback(Callback): def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs training metrics.""" if logs is not None: + logs["lr"] = self.model.optimizer.param_groups[0]["lr"] self._on_batch_end(batch, logs) def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index b240157..bd6a491 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch import Tensor from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback, CallbackList +from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA from training.trainer.util import log_val_metric, RunningAverage import wandb @@ -47,8 +47,14 @@ class Trainer: self.model = None def _configure_callbacks(self) -> None: + """Instantiate the CallbackList.""" if not self.callbacks_configured: - # Instantiate a CallbackList. + # If learning rate schedulers are present, they need to be added to the callbacks. + if self.model.swa_scheduler is not None: + self.callbacks.append(SWA()) + elif self.model.lr_scheduler is not None: + self.callbacks.append(LRScheduler()) + self.callbacks = CallbackList(self.model, self.callbacks) def compute_metrics( @@ -91,7 +97,7 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -130,7 +136,6 @@ class Trainer: batch: int, samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], - use_swa: bool = False, ) -> Dict: """Performs the validation step.""" # Pass the tensor to the device for computation. @@ -143,10 +148,7 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - if use_swa and self.model.swa_network is None: - output = self.model.swa_network(data) - else: - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -238,7 +240,7 @@ class Trainer: self.model.eval() # Check if SWA network is available. - use_swa = True if self.model.swa_network is not None else False + self.model.use_swa_model() # Running average for the loss. loss_avg = RunningAverage() @@ -247,7 +249,7 @@ class Trainer: summary = [] for batch, samples in enumerate(self.model.test_dataloader()): - metrics = self.validation_step(batch, samples, loss_avg, use_swa) + metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) # Compute mean of all test metrics. |