author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2023-08-25 23:18:31 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2023-08-25 23:18:31 +0200
commit     0421daf6bd97596703f426ba61c401599b538eeb (patch)
tree       3346a27d09bb16e3c891a7d4f3eaf5721a2dd378 /text_recognizer/model/base.py
parent     54d8b230eedfdf587e2d2d214d65582fe78c47eb (diff)
Rename
Diffstat (limited to 'text_recognizer/model/base.py')
-rw-r--r--  text_recognizer/model/base.py  |  96
1 file changed, 96 insertions, 0 deletions
diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py
new file mode 100644
index 0000000..1cff796
--- /dev/null
+++ b/text_recognizer/model/base.py
@@ -0,0 +1,96 @@
+"""Base PyTorch Lightning model."""
+from typing import Any, Dict, Optional, Tuple
+
+import hydra
+import torch
+from loguru import logger as log
+from omegaconf import DictConfig
+import pytorch_lightning as L
+from torch import nn, Tensor
+from torchmetrics import Accuracy
+
+from text_recognizer.data.tokenizer import Tokenizer
+
+
+class LitBase(L.LightningModule):
+ """Abstract PyTorch Lightning class."""
+
+ def __init__(
+ self,
+        network: nn.Module,
+        loss_fn: nn.Module,
+ optimizer_config: DictConfig,
+ lr_scheduler_config: Optional[DictConfig],
+ tokenizer: Tokenizer,
+ ) -> None:
+ super().__init__()
+
+ self.network = network
+ self.loss_fn = loss_fn
+ self.optimizer_config = optimizer_config
+ self.lr_scheduler_config = lr_scheduler_config
+ self.tokenizer = tokenizer
+ ignore_index = int(self.tokenizer.get_value("<p>"))
+        # Accuracy metrics that ignore the padding token index
+ self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+ self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+ self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
+
+ def optimizer_zero_grad(
+ self,
+ epoch: int,
+ batch_idx: int,
+        optimizer: torch.optim.Optimizer,
+ ) -> None:
+        """Set gradients to None instead of zero for better performance."""
+ optimizer.zero_grad(set_to_none=True)
+
+    def _configure_optimizer(self) -> torch.optim.Optimizer:
+ """Configures the optimizer."""
+ log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
+ return hydra.utils.instantiate(
+ self.optimizer_config, params=self.network.parameters()
+ )
+
+ def _configure_lr_schedulers(
+        self, optimizer: torch.optim.Optimizer
+ ) -> Optional[Dict[str, Any]]:
+ """Configures the lr scheduler."""
+ log.info(
+ f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
+ )
+ monitor = self.lr_scheduler_config.pop("monitor")
+ interval = self.lr_scheduler_config.pop("interval")
+ return {
+ "monitor": monitor,
+ "interval": interval,
+ "scheduler": hydra.utils.instantiate(
+ self.lr_scheduler_config, optimizer=optimizer
+ ),
+ }
+
+ def configure_optimizers(
+ self,
+ ) -> Dict[str, Any]:
+ """Configures optimizer and lr scheduler."""
+ optimizer = self._configure_optimizer()
+ if self.lr_scheduler_config is not None:
+ scheduler = self._configure_lr_schedulers(optimizer)
+ return {"optimizer": optimizer, "lr_scheduler": scheduler}
+ return {"optimizer": optimizer}
+
+ def forward(self, data: Tensor) -> Tensor:
+        """Forward pass through the network."""
+ return self.network(data)
+
+ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+        """Training step, implemented by concrete subclasses."""
+ pass
+
+ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step, implemented by concrete subclasses."""
+ pass
+
+ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step, implemented by concrete subclasses."""
+ pass
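
For orientation, the sketch below is not part of the commit: it shows how a concrete subclass might fill in the abstract step hooks, and what Hydra configs compatible with _configure_optimizer and _configure_lr_schedulers could look like. The class name LitTransformer, the logging keys, and the config values are illustrative assumptions, not code from this repository.

from typing import Tuple

from omegaconf import OmegaConf
from torch import Tensor

from text_recognizer.model.base import LitBase


class LitTransformer(LitBase):
    """Hypothetical concrete model that implements the abstract step hooks."""

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        data, targets = batch
        logits = self(data)  # forward() delegates to self.network
        loss = self.loss_fn(logits, targets)
        # The accuracy metrics ignore the <p> padding index set in LitBase.__init__.
        self.train_acc(logits, targets)
        self.log("train/loss", loss)
        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        data, targets = batch
        logits = self(data)
        self.val_acc(logits, targets)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)


# A minimal optimizer config that _configure_optimizer can hand to
# hydra.utils.instantiate together with the network parameters.
optimizer_config = OmegaConf.create(
    {"_target_": "torch.optim.AdamW", "lr": 3e-4, "weight_decay": 1e-2}
)

# A matching lr scheduler config; "monitor" and "interval" are popped by
# _configure_lr_schedulers before the scheduler itself is instantiated.
lr_scheduler_config = OmegaConf.create(
    {
        "_target_": "torch.optim.lr_scheduler.ReduceLROnPlateau",
        "monitor": "val/acc",
        "interval": "epoch",
        "mode": "max",
    }
)

With configs like these, configure_optimizers returns both the optimizer and a scheduler dict that Lightning monitors on val/acc at the end of each epoch.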