| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:18:31 +0200 |
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2023-08-25 23:18:31 +0200 |
| commit | 0421daf6bd97596703f426ba61c401599b538eeb (patch) | |
| tree | 3346a27d09bb16e3c891a7d4f3eaf5721a2dd378 /text_recognizer/model/base.py | |
| parent | 54d8b230eedfdf587e2d2d214d65582fe78c47eb (diff) | |
Rename
Diffstat (limited to 'text_recognizer/model/base.py')
-rw-r--r-- | text_recognizer/model/base.py | 96 |
1 file changed, 96 insertions, 0 deletions
```diff
diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py
new file mode 100644
index 0000000..1cff796
--- /dev/null
+++ b/text_recognizer/model/base.py
@@ -0,0 +1,96 @@
+"""Base PyTorch Lightning model."""
+from typing import Any, Dict, Optional, Tuple, Type
+
+import hydra
+import torch
+from loguru import logger as log
+from omegaconf import DictConfig
+import pytorch_lightning as L
+from torch import nn, Tensor
+from torchmetrics import Accuracy
+
+from text_recognizer.data.tokenizer import Tokenizer
+
+
+class LitBase(L.LightningModule):
+    """Abstract PyTorch Lightning class."""
+
+    def __init__(
+        self,
+        network: Type[nn.Module],
+        loss_fn: Type[nn.Module],
+        optimizer_config: DictConfig,
+        lr_scheduler_config: Optional[DictConfig],
+        tokenizer: Tokenizer,
+    ) -> None:
+        super().__init__()
+
+        self.network = network
+        self.loss_fn = loss_fn
+        self.optimizer_config = optimizer_config
+        self.lr_scheduler_config = lr_scheduler_config
+        self.tokenizer = tokenizer
+        ignore_index = int(self.tokenizer.get_value("<p>"))
+        # Placeholders
+        self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+        self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+        self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+
+    def optimizer_zero_grad(
+        self,
+        epoch: int,
+        batch_idx: int,
+        optimizer: Type[torch.optim.Optimizer],
+    ) -> None:
+        """Optimal way to set grads to zero."""
+        optimizer.zero_grad(set_to_none=True)
+
+    def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
+        """Configures the optimizer."""
+        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
+        return hydra.utils.instantiate(
+            self.optimizer_config, params=self.network.parameters()
+        )
+
+    def _configure_lr_schedulers(
+        self, optimizer: Type[torch.optim.Optimizer]
+    ) -> Optional[Dict[str, Any]]:
+        """Configures the lr scheduler."""
+        log.info(
+            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
+        )
+        monitor = self.lr_scheduler_config.pop("monitor")
+        interval = self.lr_scheduler_config.pop("interval")
+        return {
+            "monitor": monitor,
+            "interval": interval,
+            "scheduler": hydra.utils.instantiate(
+                self.lr_scheduler_config, optimizer=optimizer
+            ),
+        }
+
+    def configure_optimizers(
+        self,
+    ) -> Dict[str, Any]:
+        """Configures optimizer and lr scheduler."""
+        optimizer = self._configure_optimizer()
+        if self.lr_scheduler_config is not None:
+            scheduler = self._configure_lr_schedulers(optimizer)
+            return {"optimizer": optimizer, "lr_scheduler": scheduler}
+        return {"optimizer": optimizer}
+
+    def forward(self, data: Tensor) -> Tensor:
+        """Feedforward pass."""
+        return self.network(data)
+
+    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+        """Training step."""
+        pass
+
+    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step."""
+        pass
+
+    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step."""
+        pass
```
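The commit only introduces the abstract base; the step methods are left as stubs for subclasses to fill in. Below is a minimal, hypothetical sketch (not part of this commit) of how a concrete model might implement `training_step` and `validation_step` on top of `LitBase`, reusing the injected `loss_fn` and the padding-aware accuracy metrics. The class name `LitSimple`, the logging keys, and the assumed tensor shapes are illustrative only.

```python
"""Hypothetical subclass of LitBase, sketched for illustration (not in this commit)."""
from typing import Tuple

from torch import Tensor

from text_recognizer.model.base import LitBase


class LitSimple(LitBase):  # assumed name, for illustration only
    """Concrete model: predicts a token distribution per sequence position."""

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        data, targets = batch
        logits = self(data)  # LitBase.forward -> self.network(data)
        loss = self.loss_fn(logits, targets)  # loss_fn injected via the constructor
        self.train_acc(logits, targets)  # padding token masked via ignore_index
        self.log("train/loss", loss)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        data, targets = batch
        logits = self(data)
        loss = self.loss_fn(logits, targets)
        self.val_acc(logits, targets)
        self.log("val/loss", loss, prog_bar=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
```

Construction follows the Hydra pattern the base class relies on: `optimizer_config` is expected to carry a `_target_` pointing at an optimizer class, and `lr_scheduler_config`, when given, must also include `monitor` and `interval` keys, since `_configure_lr_schedulers` pops both before instantiating the scheduler.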