diff options
Diffstat (limited to 'src/training/trainer/callbacks')
-rw-r--r-- | src/training/trainer/callbacks/__init__.py | 14 | ||||
-rw-r--r-- | src/training/trainer/callbacks/lr_schedulers.py | 121 | ||||
-rw-r--r-- | src/training/trainer/callbacks/progress_bar.py | 1 | ||||
-rw-r--r-- | src/training/trainer/callbacks/wandb_callbacks.py | 1 |
4 files changed, 27 insertions, 110 deletions
diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index c81e4bf..e1bd858 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -3,12 +3,7 @@ from .base import Callback, CallbackList from .checkpoint import Checkpoint from .early_stopping import EarlyStopping from .lr_schedulers import ( - CosineAnnealingLR, - CyclicLR, - MultiStepLR, - OneCycleLR, - ReduceLROnPlateau, - StepLR, + LRScheduler, SWA, ) from .progress_bar import ProgressBar @@ -18,15 +13,10 @@ __all__ = [ "Callback", "CallbackList", "Checkpoint", - "CosineAnnealingLR", "EarlyStopping", + "LRScheduler", "WandbCallback", "WandbImageLogger", - "CyclicLR", - "MultiStepLR", - "OneCycleLR", "ProgressBar", - "ReduceLROnPlateau", - "StepLR", "SWA", ] diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index bb41d2d..907e292 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -7,113 +7,27 @@ from training.trainer.callbacks import Callback from text_recognizer.models import Model -class StepLR(Callback): - """Callback for StepLR.""" +class LRScheduler(Callback): + """Generic learning rate scheduler callback.""" def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class MultiStepLR(Callback): - """Callback for MultiStepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class ReduceLROnPlateau(Callback): - """Callback for ReduceLROnPlateau.""" - - def __init__(self) -> None: - """Initializes the callback.""" super().__init__() - self.lr_scheduler = None def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" - val_loss = logs["val_loss"] - self.lr_scheduler.step(val_loss) - - -class CyclicLR(Callback): - """Callback for CyclicLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class OneCycleLR(Callback): - """Callback for OneCycleLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler + if self.interval == "epoch": + self.lr_scheduler.step() def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class CosineAnnealingLR(Callback): - """Callback for Cosine Annealing.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() + if self.interval == "step": + self.lr_scheduler.step() class SWA(Callback): @@ -122,21 +36,32 @@ class SWA(Callback): def __init__(self) -> None: """Initializes the callback.""" super().__init__() + self.lr_scheduler = None + self.interval = None self.swa_scheduler = None + self.swa_start = None + self.current_epoch = 1 def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.swa_start = self.model.swa_start - self.swa_scheduler = self.model.lr_scheduler - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] + self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] + self.swa_start = self.model.swa_scheduler["swa_start"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" if epoch > self.swa_start: self.model.swa_network.update_parameters(self.model.network) self.swa_scheduler.step() - else: + elif self.interval == "epoch": + self.lr_scheduler.step() + self.current_epoch = epoch + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if self.current_epoch < self.swa_start and self.interval == "step": self.lr_scheduler.step() def on_fit_end(self) -> None: diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 7829fa0..6c4305a 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -11,6 +11,7 @@ class ProgressBar(Callback): def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: """Initializes the tqdm callback.""" self.epochs = epochs + print(epochs, type(epochs)) self.log_batch_frequency = log_batch_frequency self.progress_bar = None self.val_metrics = {} diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 6643a44..d2df4d7 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -32,6 +32,7 @@ class WandbCallback(Callback): def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs training metrics.""" if logs is not None: + logs["lr"] = self.model.optimizer.param_groups[0]["lr"] self._on_batch_end(batch, logs) def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: |