summary | refs | log | tree | commit | diff
path: root/text_recognizer/models
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- text_recognizer/models/base.py | 30
-rw-r--r-- text_recognizer/models/metrics.py | 32
2 files changed, 53 insertions, 9 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index e86b478..0c70625 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -6,6 +6,7 @@ import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
+import torchmetrics
from text_recognizer import networks
@@ -13,7 +14,14 @@ from text_recognizer import networks
class BaseModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
- def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None:
+ def __init__(
+ self,
+ network_args: Dict,
+ optimizer_args: Dict,
+ lr_scheduler_args: Dict,
+ criterion_args: Dict,
+ monitor: str = "val_loss",
+ ) -> None:
super().__init__()
self.monitor = monitor
self.network = getattr(networks, network_args["type"])(**network_args["args"])
@@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule):
self.loss_fn = self.configure_criterion(criterion_args)
# Accuracy metric
- self.train_acc = pl.metrics.Accuracy()
- self.val_acc = pl.metrics.Accuracy()
- self.test_acc = pl.metrics.Accuracy()
+ self.train_acc = torchmetrics.Accuracy()
+ self.val_acc = torchmetrics.Accuracy()
+ self.test_acc = torchmetrics.Accuracy()
@staticmethod
def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
@@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule):
optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
args = {} or self.lr_scheduler_args["args"]
- scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args)
- return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor}
+ scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
+ **args
+ )
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": scheduler,
+ "monitor": self.monitor,
+ }
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""
@@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule):
loss = self.loss_fn(logits, targets)
self.log("train_loss", loss)
self.train_acc(logits, targets)
- self.log("train_acc": self.train_acc, on_step=False, on_epoch=True)
+ self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule):
logits = self(data)
self.test_acc(logits, targets)
self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
-
-
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
new file mode 100644
index 0000000..58d0537
--- /dev/null
+++ b/text_recognizer/models/metrics.py
@@ -0,0 +1,32 @@
+"""Character Error Rate (CER)."""
+from typing import Sequence
+
+import editdistance
+import torch
+from torch import Tensor
+import torchmetrics
+
+
+class CharacterErrorRate(torchmetrics.Metric):
+    """Character error rate: mean per-sample Levenshtein distance, length-normalized."""
+
+    def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+        super().__init__()  # NOTE: extra positional *args are accepted but not used.
+        self.ignore_tokens = set(ignore_tokens)  # Token ids stripped before comparing.
+        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, targets: Tensor) -> None:
+        """Accumulate the normalized edit distance of every sample in the batch."""
+        bsz = preds.shape[0]  # Samples are indexed along the first dimension.
+        for index in range(bsz):
+            pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
+            target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+            distance = editdistance.distance(pred, target)
+            error = distance / max(len(pred), len(target), 1)  # 1 guards empty pairs.
+            self.error += error
+        self.total += bsz  # Counted once per batch, not once per sample.
+
+    def compute(self) -> Tensor:
+        """Return the accumulated error divided by the number of samples seen."""
+        return self.error / self.total