path: root/src/training/trainer/callbacks/lr_schedulers.py
"""Callbacks for learning rate schedulers."""
from typing import Dict, Optional, Type

from torch.optim.swa_utils import update_bn
from training.trainer.callbacks import Callback

from text_recognizer.models import Model


class StepLR(Callback):
    """Callback for StepLR."""

    def __init__(self) -> None:
        """Initializes the callback."""
        super().__init__()
        self.lr_scheduler = None

    def set_model(self, model: Type[Model]) -> None:
        """Sets the model and lr scheduler."""
        self.model = model
        self.lr_scheduler = self.model.lr_scheduler

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Takes a step at the end of every epoch."""
        self.lr_scheduler.step()


class MultiStepLR(Callback):
    """Callback for MultiStepLR."""

    def __init__(self) -> None:
        """Initializes the callback."""
        super().__init__()
        self.lr_scheduler = None

    def set_model(self, model: Type[Model]) -> None:
        """Sets the model and lr scheduler."""
        self.model = model
        self.lr_scheduler = self.model.lr_scheduler

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Takes a step at the end of every epoch."""
        self.lr_scheduler.step()


class ReduceLROnPlateau(Callback):
    """Callback for ReduceLROnPlateau."""

    def __init__(self) -> None:
        """Initializes the callback."""
        super().__init__()
        self.lr_scheduler = None

    def set_model(self, model: Type[Model]) -> None:
        """Sets the model and lr scheduler."""
        self.model = model
        self.lr_scheduler = self.model.lr_scheduler

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Takes a step at the end of every epoch."""
        val_loss = logs["val_loss"]
        self.lr_scheduler.step(val_loss)


class CyclicLR(Callback):
    """Callback for CyclicLR."""

    def __init__(self) -> None:
        """Initializes the callback."""
        super().__init__()
        self.lr_scheduler = None

    def set_model(self, model: Type[Model]) -> None:
        """Sets the model and lr scheduler."""
        self.model = model
        self.lr_scheduler = self.model.lr_scheduler

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Takes a step at the end of every training batch."""
        self.lr_scheduler.step()


class OneCycleLR(Callback):
    """Callback for OneCycleLR."""

    def __init__(self) -> None:
        """Initializes the callback."""
        super().__init__()
        self.lr_scheduler = None

    def set_model(self, model: Type[Model]) -> None:
        """Sets the model and lr scheduler."""
        self.model = model
        self.lr_scheduler = self.model.lr_scheduler

    def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None:
        """Takes a step at the end of every training batch."""
        self.lr_scheduler.step()


class CosineAnnealingLR(Callback):
    """Callback for Cosine Annealing."""

    def __init__(self) -> None:
        """Initializes the callback."""
        super().__init__()
        self.lr_scheduler = None

    def set_model(self, model: Type[Model]) -> None:
        """Sets the model and lr scheduler."""
        self.model = model
        self.lr_scheduler = self.model.lr_scheduler

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Takes a step at the end of every epoch."""
        self.lr_scheduler.step()


class SWA(Callback):
    """Stochastic Weight Averaging callback."""

    def __init__(self) -> None:
        """Initializes the callback."""
        super().__init__()
        self.swa_scheduler = None

    def set_model(self, model: Type[Model]) -> None:
        """Sets the model and lr scheduler."""
        self.model = model
        self.swa_start = self.model.swa_start
        # Assumption: the model exposes a dedicated SWA scheduler (e.g. a
        # torch.optim.swa_utils.SWALR instance) as ``swa_scheduler``, separate
        # from the regular ``lr_scheduler``.
        self.swa_scheduler = self.model.swa_scheduler
        self.lr_scheduler = self.model.lr_scheduler

    def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None:
        """Takes a step at the end of every training batch."""
        if epoch > self.swa_start:
            self.model.swa_network.update_parameters(self.model.network)
            self.swa_scheduler.step()
        else:
            self.lr_scheduler.step()

    def on_fit_end(self) -> None:
        """Update batch norm statistics for the swa model at the end of training."""
        if self.model.swa_network:
            update_bn(
                self.model.val_dataloader(),
                self.model.swa_network,
                device=self.model.device,
            )
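

# Illustrative sketch only: a self-contained demonstration of the SWA stepping
# pattern that the ``SWA`` callback above implements, written against plain
# PyTorch objects. The toy network, optimizer, schedulers, and data below are
# assumptions for demonstration purposes and are not part of the trainer's
# Model/Callback API. Guarded so it only runs when this module is executed
# directly.
if __name__ == "__main__":
    import torch
    from torch import nn
    from torch.optim.swa_utils import AveragedModel, SWALR
    from torch.utils.data import DataLoader, TensorDataset

    network = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.Linear(8, 2))
    optimizer = torch.optim.SGD(network.parameters(), lr=0.1)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    swa_network = AveragedModel(network)
    swa_scheduler = SWALR(optimizer, swa_lr=0.01)
    swa_start = 5

    dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=16)
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(10):
        for data, targets in loader:
            optimizer.zero_grad()
            loss_fn(network(data), targets).backward()
            optimizer.step()
        # Mirrors SWA.on_epoch_end: after ``swa_start``, update the weight
        # average and step the SWA scheduler; before it, step the regular one.
        if epoch > swa_start:
            swa_network.update_parameters(network)
            swa_scheduler.step()
        else:
            lr_scheduler.step()

    # Mirrors SWA.on_fit_end: recompute batch norm statistics for the averaged
    # network before it is used for evaluation.
    update_bn(loader, swa_network)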