"""Base PyTorch Lightning model.""" from typing import Any, Dict, Optional, Tuple, Type import hydra from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor from torchmetrics import Accuracy from text_recognizer.data.mappings import EmnistMapping class LitBase(LightningModule): """Abstract PyTorch Lightning class.""" def __init__( self, network: Type[nn.Module], loss_fn: Type[nn.Module], optimizer_configs: DictConfig, lr_scheduler_configs: Optional[DictConfig], mapping: EmnistMapping, ) -> None: super().__init__() self.network = network self.loss_fn = loss_fn self.optimizer_configs = optimizer_configs self.lr_scheduler_configs = lr_scheduler_configs self.mapping = mapping # Placeholders self.train_acc = Accuracy() self.val_acc = Accuracy() self.test_acc = Accuracy() def optimizer_zero_grad( self, epoch: int, batch_idx: int, optimizer: Type[torch.optim.Optimizer], optimizer_idx: int, ) -> None: """Optimal way to set grads to zero.""" optimizer.zero_grad(set_to_none=True) def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: """Configures the optimizer.""" return hydra.utils.instantiate( self.optimizer_config, params=self.network.parameters() ) def _configure_lr_schedulers( self, optimizer: Type[torch.optim.Optimizer] ) -> Dict[str, Any]: """Configures the lr scheduler.""" log.info( f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" ) monitor = self.lr_scheduler_config.monitor interval = self.lr_scheduler_config.interval del self.lr_scheduler_config.monitor del self.lr_scheduler_config.interval return { "monitor": monitor, "interval": interval, "scheduler": hydra.utils.instantiate( self.lr_scheduler_config, optimizer=optimizer ), } def configure_optimizers( self, ) -> Dict[str, Any]: """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() scheduler = self._configure_lr_schedulers(optimizer) return {"optimizer": optimizer, "scheduler": scheduler} def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" return self.network(data) def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" pass def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" pass def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" pass