summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/data/iam_extended_paragraphs.py1
-rw-r--r--text_recognizer/models/base.py30
-rw-r--r--text_recognizer/models/metrics.py32
-rw-r--r--text_recognizer/networks/transformer/positional_encoding.py16
4 files changed, 63 insertions, 16 deletions
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 51050fc..d2529b4 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -71,5 +71,6 @@ class IAMExtendedParagraphs(BaseDataModule):
)
return basic + data
+
def show_dataset_info() -> None:
load_and_print_info(IAMExtendedParagraphs)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index e86b478..0c70625 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -6,6 +6,7 @@ import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
+import torchmetrics
from text_recognizer import networks
@@ -13,7 +14,14 @@ from text_recognizer import networks
class BaseModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
- def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None:
+ def __init__(
+ self,
+ network_args: Dict,
+ optimizer_args: Dict,
+ lr_scheduler_args: Dict,
+ criterion_args: Dict,
+ monitor: str = "val_loss",
+ ) -> None:
super().__init__()
self.monitor = monitor
self.network = getattr(networks, network_args["type"])(**network_args["args"])
@@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule):
self.loss_fn = self.configure_criterion(criterion_args)
# Accuracy metric
- self.train_acc = pl.metrics.Accuracy()
- self.val_acc = pl.metrics.Accuracy()
- self.test_acc = pl.metrics.Accuracy()
+ self.train_acc = torchmetrics.Accuracy()
+ self.val_acc = torchmetrics.Accuracy()
+ self.test_acc = torchmetrics.Accuracy()
@staticmethod
def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
@@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule):
optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
args = {} or self.lr_scheduler_args["args"]
- scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args)
- return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor}
+ scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
+ **args
+ )
+ return {
+ "optimizer": optimizer,
+ "lr_scheduler": scheduler,
+ "monitor": self.monitor,
+ }
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""
@@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule):
loss = self.loss_fn(logits, targets)
self.log("train_loss", loss)
self.train_acc(logits, targets)
- self.log("train_acc": self.train_acc, on_step=False, on_epoch=True)
+ self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule):
logits = self(data)
self.test_acc(logits, targets)
self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
-
-
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
new file mode 100644
index 0000000..58d0537
--- /dev/null
+++ b/text_recognizer/models/metrics.py
@@ -0,0 +1,32 @@
+"""Character Error Rate (CER)."""
+from typing import Sequence
+
+import editdistance
+import torch
+from torch import Tensor
+import torchmetrics
+
+
class CharacterErrorRate(torchmetrics.Metric):
    """Character error rate metric, computed using Levenshtein distance.

    CER is accumulated as the mean over samples of
    ``editdistance(pred, target) / max(len(pred), len(target))``,
    after dropping any token ids listed in ``ignore_tokens``.
    """

    def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
        # NOTE: *args is accepted for signature compatibility but is not
        # forwarded (matching the original behavior).
        super().__init__()
        self.ignore_tokens = set(ignore_tokens)
        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: Tensor, targets: Tensor) -> None:
        """Accumulate normalized edit distance for one batch.

        Args:
            preds: Batch of predicted token indices, shape [B, T].
            targets: Batch of target token indices, shape [B, T'].
        """
        bsz = preds.shape[0]
        for index in range(bsz):
            pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
            target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
            distance = editdistance.distance(pred, target)
            # Guard the denominator: if every token in both sequences was an
            # ignore token, both lists are empty and the distance is 0.
            error = distance / max(len(pred), len(target), 1)
            self.error += error
        # BUG FIX: `total` was previously incremented by `bsz` once per sample
        # (inside the loop), inflating the denominator to bsz**2 per batch and
        # underestimating CER for any batch size > 1. Count the batch's
        # samples exactly once.
        self.total += bsz

    def compute(self) -> Tensor:
        """Compute mean CER over all samples seen so far."""
        return self.error / self.total
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d03f630..d67d297 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module):
self.dropout = nn.Dropout(p=dropout_rate)
pe = self.make_pe(hidden_dim, max_len)
self.register_buffer("pe", pe)
-
+
@staticmethod
def make_pe(hidden_dim: int, max_len: int) -> Tensor:
"""Returns positional encoding."""
@@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module):
class PositionalEncoding2D(nn.Module):
"""Positional encodings for feature maps."""
- def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None:
+ def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
super().__init__()
if hidden_dim % 2 != 0:
raise ValueError(f"Embedding depth {hidden_dim} is not even!")
@@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module):
def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
    """Returns a 2d positional encoding of shape [D, H, W].

    Half of the channels encode the height position, the other half the
    width position; each 1d encoding is tiled along the other axis.
    """
    pe_h = PositionalEncoding.make_pe(
        hidden_dim // 2, max_len=max_h
    )  # [H, 1, D // 2]
    pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)

    # BUG FIX: the width encoding must span max_w positions (not max_h),
    # and its first axis is the width axis, so the einops pattern names it
    # "w". The previous code produced shape [D // 2, max_h * max_h, 1],
    # which cannot be concatenated with pe_h.
    pe_w = PositionalEncoding.make_pe(
        hidden_dim // 2, max_len=max_w
    )  # [W, 1, D // 2]
    pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)

    pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
    return pe
@@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module):
        # Assumes x has shape [B, D, H, W]
if x.shape[1] != self.pe.shape[0]:
raise ValueError("Hidden dimensions does not match.")
- x += self.pe[:, :x.shape[2], :x.shape[3]]
+ x += self.pe[:, : x.shape[2], : x.shape[3]]
return x
-
-