Diffstat (limited to 'src/training/trainer')
-rw-r--r--  src/training/trainer/callbacks/__init__.py          14
-rw-r--r--  src/training/trainer/callbacks/lr_schedulers.py    121
-rw-r--r--  src/training/trainer/callbacks/progress_bar.py       1
-rw-r--r--  src/training/trainer/callbacks/wandb_callbacks.py    1
-rw-r--r--  src/training/trainer/train.py                        22
5 files changed, 39 insertions, 120 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index c81e4bf..e1bd858 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -3,12 +3,7 @@ from .base import Callback, CallbackList
from .checkpoint import Checkpoint
from .early_stopping import EarlyStopping
from .lr_schedulers import (
- CosineAnnealingLR,
- CyclicLR,
- MultiStepLR,
- OneCycleLR,
- ReduceLROnPlateau,
- StepLR,
+ LRScheduler,
SWA,
)
from .progress_bar import ProgressBar
@@ -18,15 +13,10 @@ __all__ = [
"Callback",
"CallbackList",
"Checkpoint",
- "CosineAnnealingLR",
"EarlyStopping",
+ "LRScheduler",
"WandbCallback",
"WandbImageLogger",
- "CyclicLR",
- "MultiStepLR",
- "OneCycleLR",
"ProgressBar",
- "ReduceLROnPlateau",
- "StepLR",
"SWA",
]
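
The hunk above collapses the per-scheduler callbacks into a single export. A minimal sketch of what downstream code can import after this change, assuming the no-argument constructors shown later in this diff (the surrounding training script is hypothetical):

    from training.trainer.callbacks import LRScheduler, SWA

    # One generic callback now covers every torch.optim scheduler; the concrete
    # scheduler instance and its stepping interval live on the model instead.
    callbacks = [LRScheduler()]
    # or, when a stochastic-weight-averaging schedule is configured on the model:
    callbacks = [SWA()]
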
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index bb41d2d..907e292 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -7,113 +7,27 @@ from training.trainer.callbacks import Callback
from text_recognizer.models import Model
-class StepLR(Callback):
- """Callback for StepLR."""
+class LRScheduler(Callback):
+ """Generic learning rate scheduler callback."""
def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class MultiStepLR(Callback):
- """Callback for MultiStepLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class ReduceLROnPlateau(Callback):
- """Callback for ReduceLROnPlateau."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
super().__init__()
- self.lr_scheduler = None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and lr scheduler."""
self.model = model
- self.lr_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
- val_loss = logs["val_loss"]
- self.lr_scheduler.step(val_loss)
-
-
-class CyclicLR(Callback):
- """Callback for CyclicLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
-
-
-class OneCycleLR(Callback):
- """Callback for OneCycleLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
+ if self.interval == "epoch":
+ self.lr_scheduler.step()
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
-
-
-class CosineAnnealingLR(Callback):
- """Callback for Cosine Annealing."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
+ if self.interval == "step":
+ self.lr_scheduler.step()
class SWA(Callback):
@@ -122,21 +36,32 @@ class SWA(Callback):
def __init__(self) -> None:
"""Initializes the callback."""
super().__init__()
+ self.lr_scheduler = None
+ self.interval = None
self.swa_scheduler = None
+ self.swa_start = None
+ self.current_epoch = 1
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and lr scheduler."""
self.model = model
- self.swa_start = self.model.swa_start
- self.swa_scheduler = self.model.lr_scheduler
- self.lr_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
+ self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
+ self.swa_start = self.model.swa_scheduler["swa_start"]
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
if epoch > self.swa_start:
self.model.swa_network.update_parameters(self.model.network)
self.swa_scheduler.step()
- else:
+ elif self.interval == "epoch":
+ self.lr_scheduler.step()
+ self.current_epoch = epoch
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if self.current_epoch < self.swa_start and self.interval == "step":
self.lr_scheduler.step()
def on_fit_end(self) -> None:
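
The rewritten lr_schedulers.py reads the scheduler from a dictionary on the model rather than from a bare attribute. Below is a minimal sketch of the configuration shape these callbacks now assume; the network, optimizer, and concrete scheduler choices are illustrative, and only the dictionary keys ("lr_scheduler", "interval", "swa_scheduler", "swa_start") come from the callback code above:

    import torch

    network = torch.nn.Linear(10, 10)
    optimizer = torch.optim.SGD(network.parameters(), lr=1e-2)

    # Consumed by LRScheduler.set_model(); "interval" selects whether step()
    # runs in on_epoch_end ("epoch") or on_train_batch_end ("step").
    lr_scheduler = {
        "lr_scheduler": torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-1, total_steps=1000
        ),
        "interval": "step",
    }

    # Consumed by SWA.set_model(); the regular scheduler steps until the epoch
    # counter passes "swa_start", after which the SWA scheduler takes over.
    swa_scheduler = {
        "swa_scheduler": torch.optim.swa_utils.SWALR(optimizer, swa_lr=5e-3),
        "swa_start": 75,
    }
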
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
index 7829fa0..6c4305a 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -11,6 +11,7 @@ class ProgressBar(Callback):
def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
"""Initializes the tqdm callback."""
self.epochs = epochs
+ print(epochs, type(epochs))
self.log_batch_frequency = log_batch_frequency
self.progress_bar = None
self.val_metrics = {}
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 6643a44..d2df4d7 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -32,6 +32,7 @@ class WandbCallback(Callback):
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs training metrics."""
if logs is not None:
+ logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
self._on_batch_end(batch, logs)
def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
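
The added line pulls the current learning rate straight from the optimizer so it lands in the W&B batch logs. A standalone illustration of that lookup, assuming a standard PyTorch optimizer (the toy parameter below is only for demonstration):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=1e-2)

    # PyTorch keeps the live learning rate per parameter group; logging the first
    # group's value is what the new logs["lr"] entry records each training batch.
    current_lr = optimizer.param_groups[0]["lr"]
    print(current_lr)  # 0.01
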
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index b240157..bd6a491 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -9,7 +9,7 @@ import numpy as np
import torch
from torch import Tensor
from torch.optim.swa_utils import update_bn
-from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA
from training.trainer.util import log_val_metric, RunningAverage
import wandb
@@ -47,8 +47,14 @@ class Trainer:
self.model = None
def _configure_callbacks(self) -> None:
+ """Instantiate the CallbackList."""
if not self.callbacks_configured:
- # Instantiate a CallbackList.
+ # If learning rate schedulers are present, they need to be added to the callbacks.
+ if self.model.swa_scheduler is not None:
+ self.callbacks.append(SWA())
+ elif self.model.lr_scheduler is not None:
+ self.callbacks.append(LRScheduler())
+
self.callbacks = CallbackList(self.model, self.callbacks)
def compute_metrics(
@@ -91,7 +97,7 @@ class Trainer:
# Forward pass.
# Get the network prediction.
- output = self.model.network(data)
+ output = self.model.forward(data)
# Compute the loss.
loss = self.model.loss_fn(output, targets)
@@ -130,7 +136,6 @@ class Trainer:
batch: int,
samples: Tuple[Tensor, Tensor],
loss_avg: Type[RunningAverage],
- use_swa: bool = False,
) -> Dict:
"""Performs the validation step."""
# Pass the tensor to the device for computation.
@@ -143,10 +148,7 @@ class Trainer:
# Forward pass.
# Get the network prediction.
# Use SWA if available and using test dataset.
- if use_swa and self.model.swa_network is None:
- output = self.model.swa_network(data)
- else:
- output = self.model.network(data)
+ output = self.model.forward(data)
# Compute the loss.
loss = self.model.loss_fn(output, targets)
@@ -238,7 +240,7 @@ class Trainer:
self.model.eval()
# Check if SWA network is available.
- use_swa = True if self.model.swa_network is not None else False
+ self.model.use_swa_model()
# Running average for the loss.
loss_avg = RunningAverage()
@@ -247,7 +249,7 @@ class Trainer:
summary = []
for batch, samples in enumerate(self.model.test_dataloader()):
- metrics = self.validation_step(batch, samples, loss_avg, use_swa)
+ metrics = self.validation_step(batch, samples, loss_avg)
summary.append(metrics)
# Compute mean of all test metrics.
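
The train.py changes hand the SWA decision over to the model: the trainer now calls self.model.use_swa_model() once before testing and self.model.forward(data) everywhere. The Model class itself is not part of this diff, so the sketch below is only an assumption about how those two helpers plausibly behave, not the actual text_recognizer implementation:

    class Model:
        """Assumed shape of the model-side SWA helpers; hypothetical bodies."""

        def __init__(self, network, swa_network=None):
            self.network = network
            self.swa_network = swa_network
            self._use_swa = False

        def use_swa_model(self) -> None:
            # Presumably switches evaluation to the SWA-averaged weights when they
            # exist, replacing the old use_swa flag in validation_step.
            if self.swa_network is not None:
                self._use_swa = True

        def forward(self, data):
            # Presumably dispatches to whichever network is active, so the trainer
            # no longer chooses between network and swa_network itself.
            net = self.swa_network if self._use_swa else self.network
            return net(data)
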