| author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-03 23:33:34 +0200 |
|---|---|---|
| committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-08-03 23:33:34 +0200 |
| commit | 07dd14116fe1d8148fb614b160245287533620fc (patch) | |
| tree | 63395d88b17a14ad453c52889fcf541e6cbbdd3e /src/training/callbacks/lr_schedulers.py | |
| parent | 704451318eb6b0b600ab314cb5aabfac82416bda (diff) | |
Working Emnist lines dataset.
Diffstat (limited to 'src/training/callbacks/lr_schedulers.py')
| -rw-r--r-- | src/training/callbacks/lr_schedulers.py | 97 |
1 file changed, 97 insertions, 0 deletions
```diff
diff --git a/src/training/callbacks/lr_schedulers.py b/src/training/callbacks/lr_schedulers.py
new file mode 100644
index 0000000..00c7e9b
--- /dev/null
+++ b/src/training/callbacks/lr_schedulers.py
@@ -0,0 +1,97 @@
+"""Callbacks for learning rate schedulers."""
+from typing import Callable, Dict, List, Optional, Type
+
+from training.callbacks import Callback
+
+from text_recognizer.models import Model
+
+
+class StepLR(Callback):
+    """Callback for StepLR."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+        """Takes a step at the end of every epoch."""
+        self.lr_scheduler.step()
+
+
+class MultiStepLR(Callback):
+    """Callback for MultiStepLR."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+        """Takes a step at the end of every epoch."""
+        self.lr_scheduler.step()
+
+
+class ReduceLROnPlateau(Callback):
+    """Callback for ReduceLROnPlateau."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_epoch_end(self, epoch: int, logs: Dict = {}) -> None:
+        """Takes a step at the end of every epoch."""
+        val_loss = logs["val_loss"]
+        self.lr_scheduler.step(val_loss)
+
+
+class CyclicLR(Callback):
+    """Callback for CyclicLR."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+        """Takes a step at the end of every training batch."""
+        self.lr_scheduler.step()
+
+
+class OneCycleLR(Callback):
+    """Callback for OneCycleLR."""
+
+    def __init__(self) -> None:
+        """Initializes the callback."""
+        super().__init__()
+        self.lr_scheduler = None
+
+    def set_model(self, model: Type[Model]) -> None:
+        """Sets the model and lr scheduler."""
+        self.model = model
+        self.lr_scheduler = self.model.lr_scheduler
+
+    def on_train_batch_end(self, batch: int, logs: Dict = {}) -> None:
+        """Takes a step at the end of every training batch."""
+        self.lr_scheduler.step()
```
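For context, here is a minimal usage sketch (not part of the commit) of the contract these callbacks assume: the wrapped model exposes an `lr_scheduler` attribute, and the training loop calls `set_model` once before training and then the relevant hook, `on_epoch_end` for `StepLR`/`MultiStepLR`/`ReduceLROnPlateau` or `on_train_batch_end` for `CyclicLR`/`OneCycleLR`. `DummyModel` and the bare loop below are hypothetical stand-ins for the project's `Model` and trainer; running it assumes PyTorch is installed and the repository's `src` directory is on the import path.

```python
# Hypothetical usage sketch: DummyModel and this loop are stand-ins, not repo code.
import torch

from training.callbacks.lr_schedulers import StepLR  # the callback defined above


class DummyModel:
    """Stand-in for text_recognizer.models.Model; only `lr_scheduler` is needed here."""

    def __init__(self) -> None:
        self.network = torch.nn.Linear(4, 2)
        self.optimizer = torch.optim.SGD(self.network.parameters(), lr=0.1)
        # Halve the learning rate every 10 epochs.
        self.lr_scheduler = torch.optim.lr_scheduler.StepLR(
            self.optimizer, step_size=10, gamma=0.5
        )


model = DummyModel()
callback = StepLR()
callback.set_model(model)  # the trainer wires the model in before training starts

for epoch in range(20):
    # ... forward/backward passes and optimizer.step() would run here ...
    callback.on_epoch_end(epoch)

print(model.optimizer.param_groups[0]["lr"])  # 0.1 -> 0.05 -> 0.025
```

Keeping the epoch-stepped and batch-stepped schedulers in separate callback classes means the trainer never needs to know which scheduler it is driving; it simply broadcasts the hooks, and each callback decides whether to react.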