summaryrefslogtreecommitdiff
path: root/text_recognizer/models/base.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 16:05:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 16:05:13 +0200
commit31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch)
treef529d975d18d718a5d646e93f746d8be6f2f5cfe /text_recognizer/models/base.py
parent36964354407d0fdf73bdca2f611fee1664860197 (diff)
Reformat test for CER
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r--text_recognizer/models/base.py30
1 files changed, 21 insertions, 9 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index e86b478..0c70625 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -6,6 +6,7 @@ import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
+import torchmetrics
from text_recognizer import networks
@@ -13,7 +14,14 @@ from text_recognizer import networks
class BaseModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
- def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None:
+ def __init__(
+ self,
+ network_args: Dict,
+ optimizer_args: Dict,
+ lr_scheduler_args: Dict,
+ criterion_args: Dict,
+ monitor: str = "val_loss",
+ ) -> None:
super().__init__()
self.monitor = monitor
self.network = getattr(networks, network_args["type"])(**network_args["args"])
@@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule):
self.loss_fn = self.configure_criterion(criterion_args)
# Accuracy metric
- self.train_acc = pl.metrics.Accuracy()
- self.val_acc = pl.metrics.Accuracy()
- self.test_acc = pl.metrics.Accuracy()
+ self.train_acc = torchmetrics.Accuracy()
+ self.val_acc = torchmetrics.Accuracy()
+ self.test_acc = torchmetrics.Accuracy()
@staticmethod
def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
@@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule):
optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
args = {} or self.lr_scheduler_args["args"]
- scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args)
- return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor}
+ scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
+ **args
+ )
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": scheduler,
+ "monitor": self.monitor,
+ }
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""
@@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule):
loss = self.loss_fn(logits, targets)
self.log("train_loss", loss)
self.train_acc(logits, targets)
- self.log("train_acc": self.train_acc, on_step=False, on_epoch=True)
+ self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule):
logits = self(data)
self.test_acc(logits, targets)
self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
-
-