diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:18:31 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:18:31 +0200 | 
| commit | 0421daf6bd97596703f426ba61c401599b538eeb (patch) | |
| tree | 3346a27d09bb16e3c891a7d4f3eaf5721a2dd378 /text_recognizer/model/base.py | |
| parent | 54d8b230eedfdf587e2d2d214d65582fe78c47eb (diff) | |
Rename
Diffstat (limited to 'text_recognizer/model/base.py')
| -rw-r--r-- | text_recognizer/model/base.py | 96 | 
1 files changed, 96 insertions, 0 deletions
diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py new file mode 100644 index 0000000..1cff796 --- /dev/null +++ b/text_recognizer/model/base.py @@ -0,0 +1,96 @@ +"""Base PyTorch Lightning model.""" +from typing import Any, Dict, Optional, Tuple, Type + +import hydra +import torch +from loguru import logger as log +from omegaconf import DictConfig +import pytorch_lightning as L +from torch import nn, Tensor +from torchmetrics import Accuracy + +from text_recognizer.data.tokenizer import Tokenizer + + +class LitBase(L.LightningModule): +    """Abstract PyTorch Lightning class.""" + +    def __init__( +        self, +        network: Type[nn.Module], +        loss_fn: Type[nn.Module], +        optimizer_config: DictConfig, +        lr_scheduler_config: Optional[DictConfig], +        tokenizer: Tokenizer, +    ) -> None: +        super().__init__() + +        self.network = network +        self.loss_fn = loss_fn +        self.optimizer_config = optimizer_config +        self.lr_scheduler_config = lr_scheduler_config +        self.tokenizer = tokenizer +        ignore_index = int(self.tokenizer.get_value("<p>")) +        # Placeholders +        self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) +        self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) +        self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index) + +    def optimizer_zero_grad( +        self, +        epoch: int, +        batch_idx: int, +        optimizer: Type[torch.optim.Optimizer], +    ) -> None: +        """Optimal way to set grads to zero.""" +        optimizer.zero_grad(set_to_none=True) + +    def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: +        """Configures the optimizer.""" +        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") +        return hydra.utils.instantiate( +            self.optimizer_config, params=self.network.parameters() +        ) + +    def _configure_lr_schedulers( +        self, optimizer: Type[torch.optim.Optimizer] +    ) -> Optional[Dict[str, Any]]: +        """Configures the lr scheduler.""" +        log.info( +            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" +        ) +        monitor = self.lr_scheduler_config.pop("monitor") +        interval = self.lr_scheduler_config.pop("interval") +        return { +            "monitor": monitor, +            "interval": interval, +            "scheduler": hydra.utils.instantiate( +                self.lr_scheduler_config, optimizer=optimizer +            ), +        } + +    def configure_optimizers( +        self, +    ) -> Dict[str, Any]: +        """Configures optimizer and lr scheduler.""" +        optimizer = self._configure_optimizer() +        if self.lr_scheduler_config is not None: +            scheduler = self._configure_lr_schedulers(optimizer) +            return {"optimizer": optimizer, "lr_scheduler": scheduler} +        return {"optimizer": optimizer} + +    def forward(self, data: Tensor) -> Tensor: +        """Feedforward pass.""" +        return self.network(data) + +    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: +        """Training step.""" +        pass + +    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: +        """Validation step.""" +        pass + +    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: +        """Test step.""" +        pass  |