Diffstat (limited to 'src/training/trainer/callbacks')
-rw-r--r--  src/training/trainer/callbacks/__init__.py        |  14
-rw-r--r--  src/training/trainer/callbacks/lr_schedulers.py   | 121
-rw-r--r--  src/training/trainer/callbacks/progress_bar.py    |   1
-rw-r--r--  src/training/trainer/callbacks/wandb_callbacks.py |   1
4 files changed, 27 insertions(+), 110 deletions(-)
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py
index c81e4bf..e1bd858 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/src/training/trainer/callbacks/__init__.py
@@ -3,12 +3,7 @@ from .base import Callback, CallbackList
from .checkpoint import Checkpoint
from .early_stopping import EarlyStopping
from .lr_schedulers import (
- CosineAnnealingLR,
- CyclicLR,
- MultiStepLR,
- OneCycleLR,
- ReduceLROnPlateau,
- StepLR,
+ LRScheduler,
SWA,
)
from .progress_bar import ProgressBar
@@ -18,15 +13,10 @@ __all__ = [
"Callback",
"CallbackList",
"Checkpoint",
- "CosineAnnealingLR",
"EarlyStopping",
+ "LRScheduler",
"WandbCallback",
"WandbImageLogger",
- "CyclicLR",
- "MultiStepLR",
- "OneCycleLR",
"ProgressBar",
- "ReduceLROnPlateau",
- "StepLR",
"SWA",
]
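
The six scheduler-specific callbacks are collapsed into a single generic LRScheduler, so callers import one class regardless of which torch scheduler backs it. A minimal usage sketch (the callback wiring below is an assumption; only the names come from this diff):

    from training.trainer.callbacks import LRScheduler, SWA

    # The wrapped scheduler and its step interval are read from the model
    # in set_model(), so the callback itself takes no arguments.
    callbacks = [LRScheduler()]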
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py
index bb41d2d..907e292 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/src/training/trainer/callbacks/lr_schedulers.py
@@ -7,113 +7,27 @@ from training.trainer.callbacks import Callback
from text_recognizer.models import Model
-class StepLR(Callback):
- """Callback for StepLR."""
+class LRScheduler(Callback):
+ """Generic learning rate scheduler callback."""
def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class MultiStepLR(Callback):
- """Callback for MultiStepLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
-
-
-class ReduceLROnPlateau(Callback):
- """Callback for ReduceLROnPlateau."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
super().__init__()
- self.lr_scheduler = None
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and lr scheduler."""
self.model = model
- self.lr_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every epoch."""
- val_loss = logs["val_loss"]
- self.lr_scheduler.step(val_loss)
-
-
-class CyclicLR(Callback):
- """Callback for CyclicLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
-
-
-class OneCycleLR(Callback):
- """Callback for OneCycleLR."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
+ if self.interval == "epoch":
+ self.lr_scheduler.step()
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
- self.lr_scheduler.step()
-
-
-class CosineAnnealingLR(Callback):
- """Callback for Cosine Annealing."""
-
- def __init__(self) -> None:
- """Initializes the callback."""
- super().__init__()
- self.lr_scheduler = None
-
- def set_model(self, model: Type[Model]) -> None:
- """Sets the model and lr scheduler."""
- self.model = model
- self.lr_scheduler = self.model.lr_scheduler
-
- def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
- """Takes a step at the end of every epoch."""
- self.lr_scheduler.step()
+ if self.interval == "step":
+ self.lr_scheduler.step()
class SWA(Callback):
@@ -122,21 +36,32 @@ class SWA(Callback):
def __init__(self) -> None:
"""Initializes the callback."""
super().__init__()
+ self.lr_scheduler = None
+ self.interval = None
self.swa_scheduler = None
+ self.swa_start = None
+ self.current_epoch = 1
def set_model(self, model: Type[Model]) -> None:
"""Sets the model and lr scheduler."""
self.model = model
- self.swa_start = self.model.swa_start
- self.swa_scheduler = self.model.lr_scheduler
- self.lr_scheduler = self.model.lr_scheduler
+ self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"]
+ self.interval = self.model.lr_scheduler["interval"]
+ self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"]
+ self.swa_start = self.model.swa_scheduler["swa_start"]
def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
"""Takes a step at the end of every training batch."""
if epoch > self.swa_start:
self.model.swa_network.update_parameters(self.model.network)
self.swa_scheduler.step()
- else:
+ elif self.interval == "epoch":
+ self.lr_scheduler.step()
+ self.current_epoch = epoch
+
+ def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
+ """Takes a step at the end of every training batch."""
+ if self.current_epoch < self.swa_start and self.interval == "step":
self.lr_scheduler.step()
def on_fit_end(self) -> None:
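
The rewrite changes the contract with Model: model.lr_scheduler is now a dict holding the scheduler under "lr_scheduler" plus an "interval" of "epoch" or "step", and SWA additionally reads model.swa_scheduler for the SWA scheduler and start epoch. A sketch of a matching configuration, assuming standard torch schedulers (the dict keys come from the diff; the Model wiring shown here is hypothetical):

    import torch
    from torch.optim.swa_utils import SWALR

    optimizer = torch.optim.SGD(model.network.parameters(), lr=1e-2)

    # Stepped once per training batch because interval is "step".
    model.lr_scheduler = {
        "lr_scheduler": torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=1e-2, total_steps=10_000
        ),
        "interval": "step",
    }

    # After epoch swa_start, SWA averages weights and steps SWALR instead.
    model.swa_scheduler = {
        "swa_scheduler": SWALR(optimizer, swa_lr=5e-3),
        "swa_start": 5,
    }

Note that the generic callback calls step() with no arguments, so a ReduceLROnPlateau-style scheduler that expects a validation metric would still need the old metric-aware handling.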
diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py
index 7829fa0..6c4305a 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/src/training/trainer/callbacks/progress_bar.py
@@ -11,6 +11,7 @@ class ProgressBar(Callback):
def __init__(self, epochs: int, log_batch_frequency: int = None) -> None:
"""Initializes the tqdm callback."""
self.epochs = epochs
+ print(epochs, type(epochs))
self.log_batch_frequency = log_batch_frequency
self.progress_bar = None
self.val_metrics = {}
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py
index 6643a44..d2df4d7 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/src/training/trainer/callbacks/wandb_callbacks.py
@@ -32,6 +32,7 @@ class WandbCallback(Callback):
def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
"""Logs training metrics."""
if logs is not None:
+ logs["lr"] = self.model.optimizer.param_groups[0]["lr"]
self._on_batch_end(batch, logs)
def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
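
WandbCallback now attaches the current learning rate to every training-batch log under the "lr" key, read from the optimizer's first param group (sufficient for single-group optimizers). The equivalent standalone call, for reference (wandb.log is the real API; the surrounding names are placeholders):

    import wandb

    lr = optimizer.param_groups[0]["lr"]  # first param group only
    wandb.log({"lr": lr})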