Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--  text_recognizer/models/base.py  97
1 file changed, 0 insertions, 97 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
deleted file mode 100644
index 4dd5cdf..0000000
--- a/text_recognizer/models/base.py
+++ /dev/null
@@ -1,97 +0,0 @@
-"""Base PyTorch Lightning model."""
-from typing import Any, Dict, Optional, Tuple
-
-import hydra
-import torch
-from loguru import logger as log
-from omegaconf import DictConfig
-from pytorch_lightning import LightningModule
-from torch import nn, Tensor
-from torchmetrics import Accuracy
-
-from text_recognizer.data.tokenizer import Tokenizer
-
-
-class LitBase(LightningModule):
- """Abstract PyTorch Lightning class."""
-
- def __init__(
- self,
-        network: nn.Module,
-        loss_fn: nn.Module,
- optimizer_config: DictConfig,
- lr_scheduler_config: Optional[DictConfig],
- tokenizer: Tokenizer,
- ) -> None:
- super().__init__()
-
- self.network = network
- self.loss_fn = loss_fn
- self.optimizer_config = optimizer_config
- self.lr_scheduler_config = lr_scheduler_config
- self.tokenizer = tokenizer
- ignore_index = int(self.tokenizer.get_value("<p>"))
-        # Sequence accuracy metrics; the padding token is masked via ignore_index.
- self.train_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
- self.val_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
- self.test_acc = Accuracy(mdmc_reduce="samplewise", ignore_index=ignore_index)
-
- def optimizer_zero_grad(
- self,
- epoch: int,
- batch_idx: int,
-        optimizer: torch.optim.Optimizer,
- optimizer_idx: int,
- ) -> None:
- """Optimal way to set grads to zero."""
- optimizer.zero_grad(set_to_none=True)
-
-    def _configure_optimizer(self) -> torch.optim.Optimizer:
- """Configures the optimizer."""
- log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
- return hydra.utils.instantiate(
- self.optimizer_config, params=self.network.parameters()
- )
-
-    def _configure_lr_schedulers(
-        self, optimizer: torch.optim.Optimizer
-    ) -> Dict[str, Any]:
- """Configures the lr scheduler."""
- log.info(
- f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
- )
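-        # "monitor" and "interval" are consumed by Lightning itself, so they are
-        # popped before hydra instantiates the scheduler from the remaining keys.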
- monitor = self.lr_scheduler_config.pop("monitor")
- interval = self.lr_scheduler_config.pop("interval")
- return {
- "monitor": monitor,
- "interval": interval,
- "scheduler": hydra.utils.instantiate(
- self.lr_scheduler_config, optimizer=optimizer
- ),
- }
-
- def configure_optimizers(
- self,
- ) -> Dict[str, Any]:
- """Configures optimizer and lr scheduler."""
- optimizer = self._configure_optimizer()
- if self.lr_scheduler_config is not None:
- scheduler = self._configure_lr_schedulers(optimizer)
- return {"optimizer": optimizer, "lr_scheduler": scheduler}
- return {"optimizer": optimizer}
-
- def forward(self, data: Tensor) -> Tensor:
- """Feedforward pass."""
- return self.network(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
-        raise NotImplementedError
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
-        raise NotImplementedError
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
-        raise NotImplementedError
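
For context, a minimal sketch of how a concrete model would subclass LitBase and fill in the step hooks being deleted here. The subclass name, the logging keys, and the assumption that logits are shaped (batch, classes, sequence) with targets shaped (batch, sequence) are illustrative, not taken from the repository; the sketch reuses the imports and the LitBase definition shown above.

    class LitExample(LitBase):
        """Hypothetical concrete model; illustrative only."""

        def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
            data, targets = batch
            logits = self(data)  # forward() delegates to self.network
            loss = self.loss_fn(logits, targets)
            # torchmetrics accepts raw logits; padding targets are ignored
            # via the ignore_index set in LitBase.__init__.
            self.train_acc(logits, targets)
            self.log("train/loss", loss)
            self.log("train/acc", self.train_acc)
            return loss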