1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
"""Base PyTorch Lightning model."""
import copy
from typing import Any, Dict, List, Optional, Tuple, Type

from attrs import define, field
import hydra
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
import torchmetrics

from text_recognizer.data.mappings.base import AbstractMapping
@define(eq=False)
class BaseLitModel(LightningModule):
    """Abstract PyTorch Lightning class.

    Bundles a network, a loss function, a token mapping and Hydra configs for
    optimizers / learning-rate schedulers, and builds the optimizers and
    schedulers from those configs in ``configure_optimizers``. The
    ``training_step`` / ``validation_step`` / ``test_step`` hooks here are
    no-op placeholders meant to be overridden by subclasses.
    """

    def __attrs_pre_init__(self) -> None:
        """Pre init constructor.

        attrs calls this before assigning any fields; running
        ``LightningModule.__init__`` first means the ``nn.Module`` internals
        exist by the time submodule fields such as ``network`` are set.
        """
        super().__init__()

    network: nn.Module = field()
    loss_fn: nn.Module = field()
    optimizer_configs: DictConfig = field()
    lr_scheduler_configs: Optional[DictConfig] = field()
    mapping: AbstractMapping = field()

    # Placeholders. ``factory`` creates a new metric for every instance; a
    # ``default=torchmetrics.Accuracy()`` would be evaluated once at class
    # definition time, so all instances would share one metric object (and
    # its accumulated state).
    train_acc: torchmetrics.Accuracy = field(
        init=False, factory=torchmetrics.Accuracy
    )
    val_acc: torchmetrics.Accuracy = field(init=False, factory=torchmetrics.Accuracy)
    test_acc: torchmetrics.Accuracy = field(init=False, factory=torchmetrics.Accuracy)

    def optimizer_zero_grad(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: torch.optim.Optimizer,
        optimizer_idx: int,
    ) -> None:
        """Reset gradients by setting them to ``None`` rather than zero.

        ``set_to_none=True`` avoids zero-filling the gradient tensors, which
        generally lowers memory traffic compared with in-place zeroing.
        """
        optimizer.zero_grad(set_to_none=True)

    def _configure_optimizer(self) -> List[torch.optim.Optimizer]:
        """Instantiate one optimizer per entry of ``optimizer_configs``.

        Each config entry must carry a ``parameters`` key: a dotted attribute
        path resolved against ``self`` (e.g. ``"network"``) whose module's
        ``parameters()`` are handed to the optimizer. All other keys are
        forwarded to ``hydra.utils.instantiate``.

        Returns:
            The instantiated optimizers, in config iteration order.
        """
        optimizers = []
        for optimizer_config in self.optimizer_configs.values():
            # Work on a copy: ``parameters`` must be stripped before
            # instantiation, but deleting it from the stored config would make
            # any later call to this method fail on the missing key.
            optimizer_config = copy.deepcopy(optimizer_config)
            module = self
            for attr_name in str(optimizer_config.parameters).split("."):
                module = getattr(module, attr_name)
            del optimizer_config.parameters
            log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
            optimizers.append(
                hydra.utils.instantiate(optimizer_config, params=module.parameters())
            )
        return optimizers

    def _configure_lr_schedulers(
        self, optimizers: List[torch.optim.Optimizer]
    ) -> List[Dict[str, Any]]:
        """Instantiate learning-rate schedulers paired with ``optimizers``.

        Scheduler configs are paired positionally with the optimizers via
        ``zip``, so any surplus on either side is ignored. Each config must
        carry ``monitor`` and ``interval`` keys; these are moved into the
        returned Lightning scheduler dict instead of being passed to the
        scheduler constructor.

        Args:
            optimizers: Optimizers returned by ``_configure_optimizer``.

        Returns:
            A list of ``{"monitor", "interval", "scheduler"}`` dicts, or an
            empty list when ``lr_scheduler_configs`` is ``None``.
        """
        if self.lr_scheduler_configs is None:
            return []
        schedulers = []
        for optimizer, lr_scheduler_config in zip(
            optimizers, self.lr_scheduler_configs.values()
        ):
            # Copy before stripping the non-constructor keys so the stored
            # config stays intact for any later call.
            lr_scheduler_config = copy.deepcopy(lr_scheduler_config)
            monitor = lr_scheduler_config.monitor
            interval = lr_scheduler_config.interval
            del lr_scheduler_config.monitor
            del lr_scheduler_config.interval
            log.info(
                f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
            )
            scheduler = {
                "monitor": monitor,
                "interval": interval,
                "scheduler": hydra.utils.instantiate(
                    lr_scheduler_config, optimizer=optimizer
                ),
            }
            schedulers.append(scheduler)
        return schedulers

    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]:
        """Configures optimizer and lr scheduler.

        Returns:
            ``(optimizers, schedulers)`` as built by ``_configure_optimizer``
            and ``_configure_lr_schedulers``.
        """
        optimizers = self._configure_optimizer()
        schedulers = self._configure_lr_schedulers(optimizers)
        return optimizers, schedulers

    def forward(self, data: Tensor) -> Tensor:
        """Feedforward pass: delegate directly to ``self.network``."""
        return self.network(data)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step.

        Placeholder: this base implementation does nothing and returns
        ``None``; subclasses are expected to override it.
        """
        pass

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step (no-op placeholder for subclasses to override)."""
        pass

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step (no-op placeholder for subclasses to override)."""
        pass
|