From 31e9673eef3088f08e3ee6aef8b78abd701ca329 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 4 Apr 2021 16:05:13 +0200
Subject: Reformat test for CER

---
 poetry.lock                                        | 42 +++++++++++++++++++++-
 pyproject.toml                                     |  2 ++
 tests/test_cer.py                                  | 23 ++++++++++++
 text_recognizer/data/iam_extended_paragraphs.py    |  1 +
 text_recognizer/models/base.py                     | 30 +++++++++++-----
 text_recognizer/models/metrics.py                  | 32 +++++++++++++++++
 .../networks/transformer/positional_encoding.py    | 16 +++++----
 7 files changed, 129 insertions(+), 17 deletions(-)
 create mode 100644 tests/test_cer.py
 create mode 100644 text_recognizer/models/metrics.py

diff --git a/poetry.lock b/poetry.lock
index f5b3fb0..3e5fcdf 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -340,6 +340,14 @@ toml = "*"
 [package.extras]
 pipenv = ["pipenv"]
 
+[[package]]
+name = "editdistance"
+version = "0.5.3"
+description = "Fast implementation of the edit distance(Levenshtein distance)"
+category = "main"
+optional = false
+python-versions = "*"
+
 [[package]]
 name = "einops"
 version = "0.3.0"
@@ -2188,7 +2196,7 @@ multidict = ">=4.0"
 [metadata]
 lock-version = "1.1"
 python-versions = "^3.8"
-content-hash = "db4253add1258abaf637f127ba576b1ec5e0b415c8f7f93b18ecdca40bc6f042"
+content-hash = "478242efb3644e920eb8bce6710d1f0d4548635b0e8c9f758f665814af7f8300"
 
 [metadata.files]
 absl-py = [
@@ -2456,6 +2464,38 @@ dparse = [
     {file = "dparse-0.5.1-py3-none-any.whl", hash = "sha256:e953a25e44ebb60a5c6efc2add4420c177f1d8404509da88da9729202f306994"},
     {file = "dparse-0.5.1.tar.gz", hash = "sha256:a1b5f169102e1c894f9a7d5ccf6f9402a836a5d24be80a986c7ce9eaed78f367"},
 ]
+editdistance = [
+    {file = "editdistance-0.5.3-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ef4714dc9cf281863dcc3ba6d24c3cae1dde41610a78dcdfae50d743ca71d5e1"},
+    {file = "editdistance-0.5.3-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:a322354a8dfb442770902f06552b20df5184e65e84ac90cb799740915eb52212"},
+    {file = "editdistance-0.5.3-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:36a4c36d7945f5ecfa1dc92c08635d73b64769cd0af066da774437fe2c7dc80a"},
+    {file = "editdistance-0.5.3-cp27-cp27m-win32.whl", hash = "sha256:93e847cc2fbebb34a36b41337a3eb9b2034d4ff9679665b08ecc5c3c313f83a9"},
+    {file = "editdistance-0.5.3-cp27-cp27m-win_amd64.whl", hash = "sha256:d4561b602b7675f6a050cdd0e1b652007ce73bb7290019487b8919a44593d74d"},
+    {file = "editdistance-0.5.3-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:dddb0d36f698e3c942d0d5934185533d9324fbde975b3e956a19883713e86d33"},
+    {file = "editdistance-0.5.3-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:1018f0fa857b079c721583c42d2c54800fbe8c7d2c29b354a9724a0b79971cb8"},
+    {file = "editdistance-0.5.3-cp33-cp33m-win32.whl", hash = "sha256:810d93e614f35ad2916570f48ff1370ac3c001eb6941d5e836e2c1c6986fafff"},
+    {file = "editdistance-0.5.3-cp33-cp33m-win_amd64.whl", hash = "sha256:a96ac49acc7668477c13aff02ca0527c6462b026b78600602dbef04efc9250d3"},
+    {file = "editdistance-0.5.3-cp34-cp34m-macosx_10_6_intel.whl", hash = "sha256:a9167d9d5e754abd7ce68da065a636cc161e5063c322efd81159d15001d5272a"},
+    {file = "editdistance-0.5.3-cp34-cp34m-manylinux1_i686.whl", hash = "sha256:a10c61df748220b2b9e2949a10aea23ffeded28c07e610e107a8f6a4b5b92782"},
+    {file = "editdistance-0.5.3-cp34-cp34m-manylinux1_x86_64.whl", hash = "sha256:6452d750fbc49c6f04232a840f96b0f1155ff7cb2d953ce1edf075c5a394f3ea"},
+    {file = "editdistance-0.5.3-cp34-cp34m-win32.whl", hash = "sha256:1f510e6eb411ec6123ba4ebc086d5882027710d28db174985a74e13fd0eb354f"},
+    {file = "editdistance-0.5.3-cp34-cp34m-win_amd64.whl", hash = "sha256:9d6ee66f8de30ec6358083e5ecd7919a5966b38c64012c1672f326c61ff7a15f"},
+    {file = "editdistance-0.5.3-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:c1cf5ff98cfdc38046ae0f2d3ccbe1e15b0665234a04783f6558ec0a48e72dc8"},
+    {file = "editdistance-0.5.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:5f9c202b1a2f2630f7a0cdd76ad0ad55de4cd700553778c77e37379c6ac8e8bb"},
+    {file = "editdistance-0.5.3-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:553fb295802c399f0f419b616b499c241ffdcb2a70888d1e9d1bd22ba21b122f"},
+    {file = "editdistance-0.5.3-cp35-cp35m-win32.whl", hash = "sha256:0834826832e51a6c18032b13b68083e3ebfbf3daf774142ae6f2b17b35580c16"},
+    {file = "editdistance-0.5.3-cp35-cp35m-win_amd64.whl", hash = "sha256:6ccfd57221bae661304e7f9495f508aeec8f72e462d97481d55488ded87f5cbc"},
+    {file = "editdistance-0.5.3-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:25b39c836347dcbb251a6041fd3d7575b82c365923a4b13c32c699e442b1b644"},
+    {file = "editdistance-0.5.3-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fa0047a8d972ab779141eed4713811251c9f6e96e9e8a62caa8d554a0444ff74"},
+    {file = "editdistance-0.5.3-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:db65bf1f39964019040434cb924c62c9965bd0df2feb316dbe5de3f09e6a81de"},
+    {file = "editdistance-0.5.3-cp36-cp36m-win32.whl", hash = "sha256:cc65c2cd68751a966f7468537b4a6fd7d9107d49e139d8efd5734ee6f48d3126"},
+    {file = "editdistance-0.5.3-cp36-cp36m-win_amd64.whl", hash = "sha256:fe7e6a90476976d7e5abc9472acb0311b7cdc76d84190f8f6c317234680c5de3"},
+    {file = "editdistance-0.5.3-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:25dd59d7f17a38203c5e433f5b11f64a8d1042d876d0dc00b324dda060d12e81"},
+    {file = "editdistance-0.5.3-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:61486173447a153cccbd52eb63947378803f0f2a5bffebbfec500bd77fc5706d"},
+    {file = "editdistance-0.5.3-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:cd49e9b22972b15527d53e06918c14d9fe228ae362a57476d16b0cad3e14e0c8"},
+    {file = "editdistance-0.5.3-cp37-cp37m-win32.whl", hash = "sha256:503c6f69f4901d8a63f3748e4b0eccb2a89e6844b0879a7e256cab439297d379"},
+    {file = "editdistance-0.5.3-cp37-cp37m-win_amd64.whl", hash = "sha256:ee4ed815bc5137a794095368580334e430ff26c73a05c67e76b39f535b363a0f"},
+    {file = "editdistance-0.5.3.tar.gz", hash = "sha256:89d016dda04649b2c49e12b34337755a7b612bfd690420edd50ab31787120c1f"},
+]
 einops = [
     {file = "einops-0.3.0-py2.py3-none-any.whl", hash = "sha256:a91c6190ceff7d513d74ca9fd701dfa6a1ffcdd98ea0ced14350197c07f75c73"},
     {file = "einops-0.3.0.tar.gz", hash = "sha256:a3b0935a4556f012cd5fa1851373f63366890a3f6698d117afea55fd2a40c1fc"},
diff --git a/pyproject.toml b/pyproject.toml
index 252678f..abf69ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,8 @@ sentencepiece = "^0.1.95"
 pytorch-lightning = "^1.2.4"
 Pillow = "^8.1.2"
 madgrad = "^1.0"
+editdistance = "^0.5.3"
+torchmetrics = "^0.2.0"
 
 [tool.poetry.dev-dependencies]
 pytest = "^5.4.2"
diff --git a/tests/test_cer.py b/tests/test_cer.py
new file mode 100644
index 0000000..30d58b2
--- /dev/null
+++ b/tests/test_cer.py
@@ -0,0 +1,23 @@
+"""Test the CER metric."""
+import torch
+
+from text_recognizer.models.metrics import CharacterErrorRate
+
+
+def test_character_error_rate() -> None:
+    """Test CER computation."""
+    metric = CharacterErrorRate([0, 1])
+    preds = torch.Tensor(
+        [
+            [0, 2, 2, 3, 3, 1],  # error will be 0
+            [0, 2, 1, 1, 1, 1],  # error will be 0.75
+            [0, 2, 2, 4, 4, 1],  # error will be 0.5
+        ]
+    )
+
+    targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]])
+    metric(preds, targets)
+    print(metric.compute())
+    assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3)
+
+
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 51050fc..d2529b4 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -71,5 +71,6 @@ class IAMExtendedParagraphs(BaseDataModule):
         )
         return basic + data
 
+
 def show_dataset_info() -> None:
     load_and_print_info(IAMExtendedParagraphs)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index e86b478..0c70625 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -6,6 +6,7 @@ import pytorch_lightning as pl
 import torch
 from torch import nn
 from torch import Tensor
+import torchmetrics
 
 from text_recognizer import networks
 
@@ -13,7 +14,14 @@ from text_recognizer import networks
 class BaseModel(pl.LightningModule):
     """Abstract PyTorch Lightning class."""
 
-    def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None:
+    def __init__(
+        self,
+        network_args: Dict,
+        optimizer_args: Dict,
+        lr_scheduler_args: Dict,
+        criterion_args: Dict,
+        monitor: str = "val_loss",
+    ) -> None:
         super().__init__()
         self.monitor = monitor
         self.network = getattr(networks, network_args["type"])(**network_args["args"])
@@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule):
         self.loss_fn = self.configure_criterion(criterion_args)
 
         # Accuracy metric
-        self.train_acc = pl.metrics.Accuracy()
-        self.val_acc = pl.metrics.Accuracy()
-        self.test_acc = pl.metrics.Accuracy()
+        self.train_acc = torchmetrics.Accuracy()
+        self.val_acc = torchmetrics.Accuracy()
+        self.test_acc = torchmetrics.Accuracy()
 
     @staticmethod
     def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
@@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule):
             optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
 
         args = {} or self.lr_scheduler_args["args"]
-        scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args)
-        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor}
+        scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
+            **args
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": self.monitor,
+        }
 
     def forward(self, data: Tensor) -> Tensor:
         """Feedforward pass."""
@@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule):
         loss = self.loss_fn(logits, targets)
         self.log("train_loss", loss)
         self.train_acc(logits, targets)
-        self.log("train_acc": self.train_acc, on_step=False, on_epoch=True)
+        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
         return loss
 
     def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule):
         logits = self(data)
         self.test_acc(logits, targets)
         self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
-
-
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
new file mode 100644
index 0000000..58d0537
--- /dev/null
+++ b/text_recognizer/models/metrics.py
@@ -0,0 +1,32 @@
+"""Character Error Rate (CER)."""
+from typing import Sequence
+
+import editdistance
+import torch
+from torch import Tensor
+import torchmetrics
+
+
+class CharacterErrorRate(torchmetrics.Metric):
+    """Character error rate metric, computed using Levenshtein distance."""
+
+    def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+        super().__init__()
+        self.ignore_tokens = set(ignore_tokens)
+        self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+    def update(self, preds: Tensor, targets: Tensor) -> None:
+        """Update CER."""
+        bsz = preds.shape[0]
+        for index in range(bsz):
+            pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
+            target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+            distance = editdistance.distance(pred, target)
+            error = distance / max(len(pred), len(target))
+            self.error += error
+        self.total += bsz
+
+    def compute(self) -> Tensor:
+        """Compute CER."""
+        return self.error / self.total
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d03f630..d67d297 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module):
         self.dropout = nn.Dropout(p=dropout_rate)
         pe = self.make_pe(hidden_dim, max_len)
         self.register_buffer("pe", pe)
-    
+
     @staticmethod
     def make_pe(hidden_dim: int, max_len: int) -> Tensor:
         """Returns positional encoding."""
@@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module):
 class PositionalEncoding2D(nn.Module):
     """Positional encodings for feature maps."""
 
-    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None:
+    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
         super().__init__()
         if hidden_dim % 2 != 0:
             raise ValueError(f"Embedding depth {hidden_dim} is not even!")
@@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module):
 
     def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
         """Returns 2d postional encoding."""
-        pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h)  # [H, 1, D // 2]
+        pe_h = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_h
+        )  # [H, 1, D // 2]
         pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
 
-        pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h)  # [W, 1, D // 2]
+        pe_w = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_w
+        )  # [W, 1, D // 2]
         pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h)
 
         pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
@@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module):
         # Assumes x hase shape [B, D, H, W]
         if x.shape[1] != self.pe.shape[0]:
             raise ValueError("Hidden dimensions does not match.")
-        x += self.pe[:, :x.shape[2], :x.shape[3]]
+        x += self.pe[:, : x.shape[2], : x.shape[3]]
         return x
-
-
-- 
cgit v1.2.3-70-g09d2