"""Base PyTorch Lightning model."""
from typing import Any, Dict, List, Optional, Tuple

import attr
import hydra
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
import torchmetrics

from text_recognizer.data.base_mapping import AbstractMapping


@attr.s(eq=False)
class BaseLitModel(LightningModule):
    """Abstract PyTorch Lightning class."""

    def __attrs_pre_init__(self) -> None:
        """Runs the LightningModule constructor before attrs assigns the fields."""
        super().__init__()

    network: nn.Module = attr.ib()
    mapping: AbstractMapping = attr.ib()
    loss_fn: nn.Module = attr.ib()
    optimizer_configs: DictConfig = attr.ib()
    lr_scheduler_configs: Optional[DictConfig] = attr.ib()
    # Use factories so every model instance gets its own metric state; a plain
    # ``default=torchmetrics.Accuracy()`` is evaluated once and would share a
    # single metric object across all instances.
    train_acc: torchmetrics.Accuracy = attr.ib(
        init=False, factory=torchmetrics.Accuracy
    )
    val_acc: torchmetrics.Accuracy = attr.ib(
        init=False, factory=torchmetrics.Accuracy
    )
    test_acc: torchmetrics.Accuracy = attr.ib(
        init=False, factory=torchmetrics.Accuracy
    )
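
    # A rough instantiation sketch (hypothetical config names, not part of
    # this module), assuming the fields above are built with Hydra:
    #
    #     model = BaseLitModel(
    #         network=hydra.utils.instantiate(cfg.network),
    #         mapping=hydra.utils.instantiate(cfg.mapping),
    #         loss_fn=hydra.utils.instantiate(cfg.criterion),
    #         optimizer_configs=cfg.optimizers,
    #         lr_scheduler_configs=cfg.lr_schedulers,
    #     )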

    def optimizer_zero_grad(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: torch.optim.Optimizer,
        optimizer_idx: int,
    ) -> None:
        """Resets gradients; ``set_to_none=True`` frees them instead of zeroing."""
        optimizer.zero_grad(set_to_none=True)

    def _configure_optimizer(self) -> List[torch.optim.Optimizer]:
        """Configures the optimizers from the Hydra config."""
        optimizers = []
        for optimizer_config in self.optimizer_configs.values():
            # Resolve the dotted attribute path (e.g. "network") to the module
            # whose parameters this optimizer should update.
            module = self
            for m in str(optimizer_config.parameters).split("."):
                module = getattr(module, m)
            del optimizer_config.parameters
            log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
            optimizers.append(
                hydra.utils.instantiate(optimizer_config, params=module.parameters())
            )
        return optimizers
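
    # Sketch of the expected ``optimizer_configs`` shape (assumed values; the
    # actual YAML lives in the repo's Hydra config and may differ):
    #
    #     optimizers:
    #       adamw:
    #         _target_: torch.optim.AdamW
    #         lr: 3.0e-4
    #         parameters: network
    #
    # ``parameters`` is a dotted attribute path on ``self``; everything else
    # is forwarded to ``hydra.utils.instantiate``.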

    def _configure_lr_schedulers(
        self, optimizers: List[torch.optim.Optimizer]
    ) -> List[Dict[str, Any]]:
        """Configures the learning rate schedulers, one per optimizer."""
        if self.lr_scheduler_configs is None:
            return []
        schedulers = []
        for optimizer, lr_scheduler_config in zip(
            optimizers, self.lr_scheduler_configs.values()
        ):
            # Pop the Lightning bookkeeping keys, which the scheduler class
            # itself does not accept as constructor arguments.
            monitor = lr_scheduler_config.monitor
            interval = lr_scheduler_config.interval
            del lr_scheduler_config.monitor
            del lr_scheduler_config.interval

            log.info(
                f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
            )
            scheduler = {
                "monitor": monitor,
                "interval": interval,
                "scheduler": hydra.utils.instantiate(
                    lr_scheduler_config, optimizer=optimizer
                ),
            }
            schedulers.append(scheduler)

        return schedulers
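
    # Sketch of the expected ``lr_scheduler_configs`` shape (assumed values),
    # matched to the optimizers above by iteration order:
    #
    #     lr_schedulers:
    #       cosine:
    #         _target_: torch.optim.lr_scheduler.CosineAnnealingLR
    #         T_max: 1000
    #         monitor: val/loss
    #         interval: step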

    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]:
        """Configures optimizers and lr schedulers; Lightning accepts this tuple form."""
        optimizers = self._configure_optimizer()
        schedulers = self._configure_lr_schedulers(optimizers)
        return optimizers, schedulers

    def forward(self, data: Tensor) -> Tensor:
        """Feedforward pass."""
        return self.network(data)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step."""
        data, targets = batch
        logits = self(data)
        loss = self.loss_fn(logits, targets)
        self.log("train/loss", loss)
        self.train_acc(logits, targets)
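        # Logging the metric object itself (not a computed value) lets
        # Lightning accumulate state across batches and reduce it correctly at
        # epoch end.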
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step."""
        data, targets = batch
        logits = self(data)
        loss = self.loss_fn(logits, targets)
        self.log("val/loss", loss, prog_bar=True)
        self.val_acc(logits, targets)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step."""
        data, targets = batch
        logits = self(data)
        self.test_acc(logits, targets)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)