diff options
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 30 |
1 files changed, 21 insertions, 9 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index e86b478..0c70625 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -6,6 +6,7 @@ import pytorch_lightning as pl import torch from torch import nn from torch import Tensor +import torchmetrics from text_recognizer import networks @@ -13,7 +14,14 @@ from text_recognizer import networks class BaseModel(pl.LightningModule): """Abstract PyTorch Lightning class.""" - def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None: + def __init__( + self, + network_args: Dict, + optimizer_args: Dict, + lr_scheduler_args: Dict, + criterion_args: Dict, + monitor: str = "val_loss", + ) -> None: super().__init__() self.monitor = monitor self.network = getattr(networks, network_args["type"])(**network_args["args"]) @@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule): self.loss_fn = self.configure_criterion(criterion_args) # Accuracy metric - self.train_acc = pl.metrics.Accuracy() - self.val_acc = pl.metrics.Accuracy() - self.test_acc = pl.metrics.Accuracy() + self.train_acc = torchmetrics.Accuracy() + self.val_acc = torchmetrics.Accuracy() + self.test_acc = torchmetrics.Accuracy() @staticmethod def configure_criterion(criterion_args: Dict) -> Type[nn.Module]: @@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule): optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args) args = {} or self.lr_scheduler_args["args"] - scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args) - return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor} + scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])( + **args + ) + return { + "optimizer": optimizer, + "lr_scheduler": scheduler, + "monitor": self.monitor, + } def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" @@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule): loss = self.loss_fn(logits, targets) self.log("train_loss", loss) self.train_acc(logits, targets) - self.log("train_acc": self.train_acc, on_step=False, on_epoch=True) + self.log("train_acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule): logits = self(data) self.test_acc(logits, targets) self.log("test_acc", self.test_acc, on_step=False, on_epoch=True) - - |