From 1f459ba19422593de325983040e176f97cf4ffc0 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Thu, 20 Aug 2020 22:18:35 +0200 Subject: A lot of stuff working :D. ResNet implemented! --- src/training/trainer/callbacks/lr_schedulers.py | 97 +++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 src/training/trainer/callbacks/lr_schedulers.py (limited to 'src/training/trainer/callbacks/lr_schedulers.py') diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py new file mode 100644 index 0000000..ba2226a --- /dev/null +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -0,0 +1,97 @@ +"""Callbacks for learning rate schedulers.""" +from typing import Callable, Dict, List, Optional, Type + +from training.trainer.callbacks import Callback + +from text_recognizer.models import Model + + +class StepLR(Callback): + """Callback for StepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class MultiStepLR(Callback): + """Callback for MultiStepLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + self.lr_scheduler.step() + + +class ReduceLROnPlateau(Callback): + """Callback for ReduceLROnPlateau.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every epoch.""" + val_loss = logs["val_loss"] + self.lr_scheduler.step(val_loss) + + +class CyclicLR(Callback): + """Callback for CyclicLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() + + +class OneCycleLR(Callback): + """Callback for OneCycleLR.""" + + def __init__(self) -> None: + """Initializes the callback.""" + super().__init__() + self.lr_scheduler = None + + def set_model(self, model: Type[Model]) -> None: + """Sets the model and lr scheduler.""" + self.model = model + self.lr_scheduler = self.model.lr_scheduler + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + self.lr_scheduler.step() -- cgit v1.2.3-70-g09d2