summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
blob: 4e803eb9d385e38f7d4e01686aca3ab0ea83159b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
"""Base PyTorch Lightning model."""
from typing import Any, Dict, List, Union, Tuple, Type

import attr
import hydra
import loguru.logger as log
from omegaconf import DictConfig
import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
import torchmetrics


@attr.s(eq=False)
class LitBaseModel(pl.LightningModule):
    """Abstract PyTorch Lightning class.

    Bundles a network with a criterion, an optimizer and a learning rate
    scheduler, each instantiated from its hydra config, and tracks accuracy
    for the training, validation and test loops.

    The class is declared with ``eq=False`` so that attrs does not generate
    ``__eq__`` (which would also set ``__hash__`` to ``None``). ``nn.Module``
    instances have to stay hashable, since PyTorch collects modules in sets
    when it walks the module tree, e.g. in ``parameters()``.
    """

    # The wrapped network; ``forward`` delegates to it.
    network: nn.Module = attr.ib()
    # Hydra configs (each read via its ``_target_``) for the loss function,
    # the optimizer and the learning rate scheduler.
    criterion_config: DictConfig = attr.ib(converter=DictConfig)
    optimizer_config: DictConfig = attr.ib(converter=DictConfig)
    lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)

    # Values placed under the "interval" and "monitor" keys of the lr
    # scheduler dictionary returned by ``_configure_lr_scheduler``.
    interval: str = attr.ib()
    monitor: str = attr.ib(default="val/loss")

    # Loss function; set in ``__attrs_post_init__`` from ``criterion_config``.
    loss_fn = attr.ib(init=False)

    # Accuracy metrics, one per loop; set in ``__attrs_post_init__``.
    train_acc = attr.ib(init=False)
    val_acc = attr.ib(init=False)
    test_acc = attr.ib(init=False)

    def __attrs_pre_init__(self):
        """Initializes the LightningModule before attrs assigns attributes."""
        # nn.Module refuses to register submodules before its own __init__
        # has run, so this must happen ahead of the attrs generated __init__.
        super().__init__()

    def __attrs_post_init__(self):
        """Creates the loss function and the accuracy metrics."""
        self.loss_fn = self.configure_criterion()

        # Accuracy metric
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()

    def configure_criterion(self) -> nn.Module:
        """Returns a loss function instantiated from ``criterion_config``."""
        # NOTE: this has to be a regular method. It reads
        # ``self.criterion_config`` and is invoked as
        # ``self.configure_criterion()``.
        log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
        return hydra.utils.instantiate(self.criterion_config)

    def optimizer_zero_grad(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: torch.optim.Optimizer,
        optimizer_idx: int,
    ) -> None:
        """Clears the gradients by setting them to ``None``.

        Args:
            epoch: Unused here; part of the hook signature.
            batch_idx: Unused here; part of the hook signature.
            optimizer: The optimizer whose gradients are cleared.
            optimizer_idx: Unused here; part of the hook signature.
        """
        optimizer.zero_grad(set_to_none=True)

    def _configure_optimizer(self) -> torch.optim.Optimizer:
        """Configures the optimizer over all parameters of this module."""
        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
        return hydra.utils.instantiate(self.optimizer_config, params=self.parameters())

    def _configure_lr_scheduler(
        self, optimizer: torch.optim.Optimizer
    ) -> Dict[str, Any]:
        """Configures the lr scheduler.

        Args:
            optimizer: The optimizer passed to the scheduler on instantiation.

        Returns:
            A dictionary with the keys "monitor", "interval" and "scheduler".
        """
        log.info(
            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
        )
        scheduler = {
            "monitor": self.monitor,
            "interval": self.interval,
            "scheduler": hydra.utils.instantiate(
                self.lr_scheduler_config, optimizer=optimizer
            ),
        }
        return scheduler

    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]:
        """Configures optimizer and lr scheduler.

        Returns:
            A one element list with the optimizer and a one element list with
            the lr scheduler dictionary.
        """
        optimizer = self._configure_optimizer()
        scheduler = self._configure_lr_scheduler(optimizer)

        return [optimizer], [scheduler]

    def forward(self, data: Tensor) -> Tensor:
        """Feedforward pass through the wrapped network."""
        return self.network(data)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step.

        Computes the loss for the batch, logs it as "train/loss" and updates
        the training accuracy, which is logged per epoch as "train/acc".

        Args:
            batch: A tuple of input data and targets.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss for the batch.
        """
        data, targets = batch
        logits = self(data)
        loss = self.loss_fn(logits, targets)
        self.log("train/loss", loss)
        self.train_acc(logits, targets)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step.

        Logs the loss as "val/loss" and the per epoch accuracy as "val/acc",
        both shown in the progress bar.

        Args:
            batch: A tuple of input data and targets.
            batch_idx: Index of the batch (unused).
        """
        data, targets = batch
        logits = self(data)
        loss = self.loss_fn(logits, targets)
        self.log("val/loss", loss, prog_bar=True)
        self.val_acc(logits, targets)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step.

        Updates the test accuracy, logged per epoch as "test/acc". No loss is
        computed in this step.

        Args:
            batch: A tuple of input data and targets.
            batch_idx: Index of the batch (unused).
        """
        data, targets = batch
        logits = self(data)
        self.test_acc(logits, targets)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)