summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
blob: 26cc18c00838793fd5733ab799fbe9b65b3f0dcd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
"""Base PyTorch Lightning model."""
import copy
from typing import Any, Dict, List, Optional, Tuple, Type

import hydra
from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
from torchmetrics import Accuracy

from text_recognizer.data.mappings import EmnistMapping


class LitBase(LightningModule):
    """Abstract PyTorch Lightning class.

    Stores the network, loss function, optimizer / lr scheduler configs and
    the character mapping, and turns the configs into optimizers and
    schedulers in Lightning's ``configure_optimizers`` hook. The step methods
    are no-op placeholders meant to be overridden by subclasses.
    """

    def __init__(
        self,
        network: nn.Module,
        loss_fn: nn.Module,
        optimizer_configs: DictConfig,
        lr_scheduler_configs: Optional[DictConfig],
        mapping: EmnistMapping,
    ) -> None:
        """Stores the model components and creates per-stage accuracy metrics.

        Args:
            network: Module called by ``forward``.
            loss_fn: Loss module, stored for use by subclasses.
            optimizer_configs: Mapping of hydra optimizer configs. Besides its
                ``_target_``, each entry has a ``parameters`` key holding a
                dotted attribute path, resolved from ``self``, to the module
                whose parameters that optimizer updates.
            lr_scheduler_configs: Optional mapping of hydra scheduler configs,
                paired positionally with ``optimizer_configs``. Each entry
                needs ``monitor`` and ``interval`` keys besides its
                ``_target_``.
            mapping: Character mapping, stored for use by subclasses.
        """
        super().__init__()

        self.network = network
        self.loss_fn = loss_fn
        self.optimizer_configs = optimizer_configs
        self.lr_scheduler_configs = lr_scheduler_configs
        self.mapping = mapping

        # One accuracy metric per stage; presumably updated by the step
        # methods subclasses implement.
        self.train_acc = Accuracy()
        self.val_acc = Accuracy()
        self.test_acc = Accuracy()

    def optimizer_zero_grad(
        self,
        epoch: int,
        batch_idx: int,
        optimizer: torch.optim.Optimizer,
        optimizer_idx: int,
    ) -> None:
        """Resets gradients by setting them to ``None`` instead of zero-filling.

        ``set_to_none=True`` skips writing zeros into every gradient tensor.
        The next backward pass allocates fresh gradients instead.
        """
        optimizer.zero_grad(set_to_none=True)

    def _configure_optimizer(self) -> List[torch.optim.Optimizer]:
        """Instantiates one optimizer per entry in ``self.optimizer_configs``.

        Each config's ``parameters`` value is split on ``"."`` and resolved
        attribute by attribute starting from ``self``. The parameters of the
        resulting module are passed to the optimizer. The ``parameters`` key
        itself is removed before instantiation, since it is not an optimizer
        argument.

        Returns:
            The instantiated optimizers, in config order.
        """
        optimizers = []
        for optimizer_config in self.optimizer_configs.values():
            # Strip keys from a copy, not the stored config: deleting
            # ``parameters`` in place would make any later call to
            # ``configure_optimizers`` fail to find the key.
            optimizer_config = copy.deepcopy(optimizer_config)
            module = self
            for attr in str(optimizer_config.parameters).split("."):
                module = getattr(module, attr)
            del optimizer_config.parameters
            log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
            optimizers.append(
                hydra.utils.instantiate(optimizer_config, params=module.parameters())
            )
        return optimizers

    def _configure_lr_schedulers(
        self, optimizers: List[torch.optim.Optimizer]
    ) -> List[Dict[str, Any]]:
        """Instantiates the learning rate schedulers.

        Scheduler configs are paired with ``optimizers`` by position. Because
        ``zip`` stops at the shorter sequence, surplus optimizers or configs
        get no pairing.

        Args:
            optimizers: Optimizers returned by ``_configure_optimizer``.

        Returns:
            One dict per scheduler with ``monitor``, ``interval`` and
            ``scheduler`` keys, or an empty list when
            ``lr_scheduler_configs`` is ``None``.
        """
        if self.lr_scheduler_configs is None:
            return []
        schedulers = []
        for optimizer, lr_scheduler_config in zip(
            optimizers, self.lr_scheduler_configs.values()
        ):
            # Copy before removing ``monitor``/``interval`` so the stored
            # config stays intact for repeated calls.
            lr_scheduler_config = copy.deepcopy(lr_scheduler_config)

            # Extract the keys that belong in the Lightning scheduler dict
            # rather than in the scheduler constructor.
            monitor = lr_scheduler_config.monitor
            interval = lr_scheduler_config.interval
            del lr_scheduler_config.monitor
            del lr_scheduler_config.interval

            log.info(
                f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
            )
            scheduler = {
                "monitor": monitor,
                "interval": interval,
                "scheduler": hydra.utils.instantiate(
                    lr_scheduler_config, optimizer=optimizer
                ),
            }
            schedulers.append(scheduler)

        return schedulers

    def configure_optimizers(
        self,
    ) -> Tuple[List[torch.optim.Optimizer], List[Dict[str, Any]]]:
        """Configures optimizers and lr schedulers.

        Returns:
            A ``(optimizers, schedulers)`` tuple; ``schedulers`` is empty
            when no scheduler configs are set.
        """
        optimizers = self._configure_optimizer()
        schedulers = self._configure_lr_schedulers(optimizers)
        return optimizers, schedulers

    def forward(self, data: Tensor) -> Tensor:
        """Runs ``data`` through ``self.network``."""
        return self.network(data)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step placeholder; does nothing and returns ``None``."""
        pass

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step placeholder; does nothing."""
        pass

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step placeholder; does nothing."""
        pass