path: root/text_recognizer/model/base.py
blob: adcb8da6d5a65d972143cd24dd7f96eb51a8bac8 (plain)
"""Base PyTorch Lightning model."""
from typing import Any, Dict, Optional, Tuple

import hydra
import torch
from loguru import logger as log
from omegaconf import DictConfig
import pytorch_lightning as L
from torch import nn, Tensor
from torchmetrics import Accuracy

from text_recognizer.data.tokenizer import Tokenizer


class LitBase(L.LightningModule):
    """Abstract PyTorch Lightning class."""

    def __init__(
        self,
        network: nn.Module,
        loss_fn: nn.Module,
        optimizer_config: DictConfig,
        lr_scheduler_config: Optional[DictConfig],
        tokenizer: Tokenizer,
    ) -> None:
        super().__init__()

        self.network = network
        self.loss_fn = loss_fn
        self.optimizer_config = optimizer_config
        self.lr_scheduler_config = lr_scheduler_config
        self.tokenizer = tokenizer
        ignore_index = int(self.tokenizer.get_value("<p>"))
        # Accuracy metrics that ignore the padding token.
        self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
        self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
        self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)

    def optimizer_zero_grad(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: torch.optim.Optimizer,
    ) -> None:
        """Resets gradients by setting them to None, which is cheaper than zeroing in place."""
        optimizer.zero_grad(set_to_none=True)

    def _configure_optimizer(self) -> torch.optim.Optimizer:
        """Configures the optimizer."""
        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
        return hydra.utils.instantiate(
            self.optimizer_config, params=self.network.parameters()
        )

    def _configure_lr_schedulers(
        self, optimizer: torch.optim.Optimizer
    ) -> Optional[Dict[str, Any]]:
        """Configures the lr scheduler."""
        log.info(
            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
        )
        # "monitor" and "interval" are Lightning scheduler-dict keys, not scheduler
        # constructor arguments, so they are popped before instantiation.
        monitor = self.lr_scheduler_config.pop("monitor")
        interval = self.lr_scheduler_config.pop("interval")
        return {
            "monitor": monitor,
            "interval": interval,
            "scheduler": hydra.utils.instantiate(
                self.lr_scheduler_config, optimizer=optimizer
            ),
        }

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configures optimizer and lr scheduler."""
        optimizer = self._configure_optimizer()
        if self.lr_scheduler_config is not None:
            scheduler = self._configure_lr_schedulers(optimizer)
            return {"optimizer": optimizer, "lr_scheduler": scheduler}
        return {"optimizer": optimizer}

    def forward(self, data: Tensor) -> Tensor:
        """Feedforward pass."""
        return self.network(data)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step."""
        pass

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step."""
        pass

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step."""
        pass

    def is_logged_batch(self) -> bool:
        """Returns True if the trainer will log metrics for the current batch."""
        if self.trainer is None:
            return False
        else:
            return self.trainer._logger_connector.should_update_logs

    def add_on_first_batch(self, metrics: dict, output: dict, batch_idx: int) -> None:
        """Adds the given metrics to the output, but only for the first batch."""
        if batch_idx == 0:
            output.update(metrics)
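

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original module: LitBase leaves the
# *_step hooks empty, so a concrete subclass is expected to implement them.
# The class name, loss handling, logging keys, and tensor shapes below are
# assumptions for illustration only; the real subclasses live elsewhere in
# the repository.
class LitExample(LitBase):
    """Minimal concrete model assuming the network returns (B, C, S) logits."""

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Computes loss and accuracy for one training batch."""
        data, targets = batch
        logits = self(data)
        loss = self.loss_fn(logits, targets)
        self.train_acc(logits, targets)
        self.log("train/loss", loss)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Computes loss and accuracy for one validation batch."""
        data, targets = batch
        logits = self(data)
        loss = self.loss_fn(logits, targets)
        self.val_acc(logits, targets)
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/acc", self.val_acc, prog_bar=True)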