author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-04-04 16:05:13 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2021-04-04 16:05:13 +0200
commit     31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch)
tree       f529d975d18d718a5d646e93f746d8be6f2f5cfe
parent     36964354407d0fdf73bdca2f611fee1664860197 (diff)
Reformat test for CER
-rw-r--r--  poetry.lock                                                    42
-rw-r--r--  pyproject.toml                                                  2
-rw-r--r--  tests/test_cer.py                                              23
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py                 1
-rw-r--r--  text_recognizer/models/base.py                                 30
-rw-r--r--  text_recognizer/models/metrics.py                              32
-rw-r--r--  text_recognizer/networks/transformer/positional_encoding.py   16
7 files changed, 129 insertions, 17 deletions
diff --git a/poetry.lock b/poetry.lock
index f5b3fb0..3e5fcdf 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -341,6 +341,14 @@ toml = "*"
pipenv = ["pipenv"]
[[package]]
+name = "editdistance"
+version = "0.5.3"
+description = "Fast implementation of the edit distance(Levenshtein distance)"
+category = "main"
+optional = false
+python-versions = "*"
+
+[[package]]
name = "einops"
version = "0.3.0"
description = "A new flavour of deep learning operations"
@@ -2188,7 +2196,7 @@ multidict = ">=4.0"
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
-content-hash = "db4253add1258abaf637f127ba576b1ec5e0b415c8f7f93b18ecdca40bc6f042"
+content-hash = "478242efb3644e920eb8bce6710d1f0d4548635b0e8c9f758f665814af7f8300"
[metadata.files]
absl-py = [
@@ -2456,6 +2464,38 @@ dparse = [
{file = "dparse-0.5.1-py3-none-any.whl", hash = "sha256:e953a25e44ebb60a5c6efc2add4420c177f1d8404509da88da9729202f306994"},
{file = "dparse-0.5.1.tar.gz", hash = "sha256:a1b5f169102e1c894f9a7d5ccf6f9402a836a5d24be80a986c7ce9eaed78f367"},
]
+editdistance = [
+ {file = "editdistance-0.5.3-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ef4714dc9cf281863dcc3ba6d24c3cae1dde41610a78dcdfae50d743ca71d5e1"},
+ {file = "editdistance-0.5.3-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:a322354a8dfb442770902f06552b20df5184e65e84ac90cb799740915eb52212"},
+ {file = "editdistance-0.5.3-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:36a4c36d7945f5ecfa1dc92c08635d73b64769cd0af066da774437fe2c7dc80a"},
+ {file = "editdistance-0.5.3-cp27-cp27m-win32.whl", hash = "sha256:93e847cc2fbebb34a36b41337a3eb9b2034d4ff9679665b08ecc5c3c313f83a9"},
+ {file = "editdistance-0.5.3-cp27-cp27m-win_amd64.whl", hash = "sha256:d4561b602b7675f6a050cdd0e1b652007ce73bb7290019487b8919a44593d74d"},
+ {file = "editdistance-0.5.3-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:dddb0d36f698e3c942d0d5934185533d9324fbde975b3e956a19883713e86d33"},
+ {file = "editdistance-0.5.3-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:1018f0fa857b079c721583c42d2c54800fbe8c7d2c29b354a9724a0b79971cb8"},
+ {file = "editdistance-0.5.3-cp33-cp33m-win32.whl", hash = "sha256:810d93e614f35ad2916570f48ff1370ac3c001eb6941d5e836e2c1c6986fafff"},
+ {file = "editdistance-0.5.3-cp33-cp33m-win_amd64.whl", hash = "sha256:a96ac49acc7668477c13aff02ca0527c6462b026b78600602dbef04efc9250d3"},
+ {file = "editdistance-0.5.3-cp34-cp34m-macosx_10_6_intel.whl", hash = "sha256:a9167d9d5e754abd7ce68da065a636cc161e5063c322efd81159d15001d5272a"},
+ {file = "editdistance-0.5.3-cp34-cp34m-manylinux1_i686.whl", hash = "sha256:a10c61df748220b2b9e2949a10aea23ffeded28c07e610e107a8f6a4b5b92782"},
+ {file = "editdistance-0.5.3-cp34-cp34m-manylinux1_x86_64.whl", hash = "sha256:6452d750fbc49c6f04232a840f96b0f1155ff7cb2d953ce1edf075c5a394f3ea"},
+ {file = "editdistance-0.5.3-cp34-cp34m-win32.whl", hash = "sha256:1f510e6eb411ec6123ba4ebc086d5882027710d28db174985a74e13fd0eb354f"},
+ {file = "editdistance-0.5.3-cp34-cp34m-win_amd64.whl", hash = "sha256:9d6ee66f8de30ec6358083e5ecd7919a5966b38c64012c1672f326c61ff7a15f"},
+ {file = "editdistance-0.5.3-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:c1cf5ff98cfdc38046ae0f2d3ccbe1e15b0665234a04783f6558ec0a48e72dc8"},
+ {file = "editdistance-0.5.3-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:5f9c202b1a2f2630f7a0cdd76ad0ad55de4cd700553778c77e37379c6ac8e8bb"},
+ {file = "editdistance-0.5.3-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:553fb295802c399f0f419b616b499c241ffdcb2a70888d1e9d1bd22ba21b122f"},
+ {file = "editdistance-0.5.3-cp35-cp35m-win32.whl", hash = "sha256:0834826832e51a6c18032b13b68083e3ebfbf3daf774142ae6f2b17b35580c16"},
+ {file = "editdistance-0.5.3-cp35-cp35m-win_amd64.whl", hash = "sha256:6ccfd57221bae661304e7f9495f508aeec8f72e462d97481d55488ded87f5cbc"},
+ {file = "editdistance-0.5.3-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:25b39c836347dcbb251a6041fd3d7575b82c365923a4b13c32c699e442b1b644"},
+ {file = "editdistance-0.5.3-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fa0047a8d972ab779141eed4713811251c9f6e96e9e8a62caa8d554a0444ff74"},
+ {file = "editdistance-0.5.3-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:db65bf1f39964019040434cb924c62c9965bd0df2feb316dbe5de3f09e6a81de"},
+ {file = "editdistance-0.5.3-cp36-cp36m-win32.whl", hash = "sha256:cc65c2cd68751a966f7468537b4a6fd7d9107d49e139d8efd5734ee6f48d3126"},
+ {file = "editdistance-0.5.3-cp36-cp36m-win_amd64.whl", hash = "sha256:fe7e6a90476976d7e5abc9472acb0311b7cdc76d84190f8f6c317234680c5de3"},
+ {file = "editdistance-0.5.3-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:25dd59d7f17a38203c5e433f5b11f64a8d1042d876d0dc00b324dda060d12e81"},
+ {file = "editdistance-0.5.3-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:61486173447a153cccbd52eb63947378803f0f2a5bffebbfec500bd77fc5706d"},
+ {file = "editdistance-0.5.3-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:cd49e9b22972b15527d53e06918c14d9fe228ae362a57476d16b0cad3e14e0c8"},
+ {file = "editdistance-0.5.3-cp37-cp37m-win32.whl", hash = "sha256:503c6f69f4901d8a63f3748e4b0eccb2a89e6844b0879a7e256cab439297d379"},
+ {file = "editdistance-0.5.3-cp37-cp37m-win_amd64.whl", hash = "sha256:ee4ed815bc5137a794095368580334e430ff26c73a05c67e76b39f535b363a0f"},
+ {file = "editdistance-0.5.3.tar.gz", hash = "sha256:89d016dda04649b2c49e12b34337755a7b612bfd690420edd50ab31787120c1f"},
+]
einops = [
{file = "einops-0.3.0-py2.py3-none-any.whl", hash = "sha256:a91c6190ceff7d513d74ca9fd701dfa6a1ffcdd98ea0ced14350197c07f75c73"},
{file = "einops-0.3.0.tar.gz", hash = "sha256:a3b0935a4556f012cd5fa1851373f63366890a3f6698d117afea55fd2a40c1fc"},
diff --git a/pyproject.toml b/pyproject.toml
index 252678f..abf69ed 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -39,6 +39,8 @@ sentencepiece = "^0.1.95"
pytorch-lightning = "^1.2.4"
Pillow = "^8.1.2"
madgrad = "^1.0"
+editdistance = "^0.5.3"
+torchmetrics = "^0.2.0"
[tool.poetry.dev-dependencies]
pytest = "^5.4.2"
diff --git a/tests/test_cer.py b/tests/test_cer.py
new file mode 100644
index 0000000..30d58b2
--- /dev/null
+++ b/tests/test_cer.py
@@ -0,0 +1,23 @@
+"""Test the CER metric."""
+import torch
+
+from text_recognizer.models.metrics import CharacterErrorRate
+
+
+def test_character_error_rate() -> None:
+ """Test CER computation."""
+ metric = CharacterErrorRate([0, 1])
+ preds = torch.Tensor(
+ [
+ [0, 2, 2, 3, 3, 1], # error will be 0
+ [0, 2, 1, 1, 1, 1], # error will be 0.75
+ [0, 2, 2, 4, 4, 1], # error will be 0.5
+ ]
+ )
+
+ targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]])
+ metric(preds, targets)
+ print(metric.compute())
+ assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3)
+
+
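
The expected per-sample errors in the comments follow from stripping the ignore tokens (0 and 1 here) and dividing the edit distance by the length of the longer stripped sequence. A small sketch of that arithmetic, assuming editdistance.eval as the distance function:

import editdistance

target = [2, 2, 3, 3]  # [0, 2, 2, 3, 3, 1] with the ignore tokens stripped

assert editdistance.eval([2, 2, 3, 3], target) / 4 == 0.0   # identical sequences
assert editdistance.eval([2], target) / 4 == 0.75           # 3 edits over max length 4
assert editdistance.eval([2, 2, 4, 4], target) / 4 == 0.5   # 2 substitutions over max length 4

# The batch CER asserted in the test is the mean: (0 + 0.75 + 0.5) / 3
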
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 51050fc..d2529b4 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -71,5 +71,6 @@ class IAMExtendedParagraphs(BaseDataModule):
        )
        return basic + data
+
def show_dataset_info() -> None:
load_and_print_info(IAMExtendedParagraphs)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index e86b478..0c70625 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -6,6 +6,7 @@ import pytorch_lightning as pl
import torch
from torch import nn
from torch import Tensor
+import torchmetrics
from text_recognizer import networks
@@ -13,7 +14,14 @@ from text_recognizer import networks
class BaseModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
- def __init__(self, network_args: Dict, optimizer_args: Dict, lr_scheduler_args: Dict, criterion_args: Dict, monitor: str = "val_loss") -> None:
+ def __init__(
+ self,
+ network_args: Dict,
+ optimizer_args: Dict,
+ lr_scheduler_args: Dict,
+ criterion_args: Dict,
+ monitor: str = "val_loss",
+ ) -> None:
super().__init__()
self.monitor = monitor
self.network = getattr(networks, network_args["type"])(**network_args["args"])
@@ -22,9 +30,9 @@ class BaseModel(pl.LightningModule):
        self.loss_fn = self.configure_criterion(criterion_args)
        # Accuracy metric
-        self.train_acc = pl.metrics.Accuracy()
-        self.val_acc = pl.metrics.Accuracy()
-        self.test_acc = pl.metrics.Accuracy()
+        self.train_acc = torchmetrics.Accuracy()
+        self.val_acc = torchmetrics.Accuracy()
+        self.test_acc = torchmetrics.Accuracy()
    @staticmethod
    def configure_criterion(criterion_args: Dict) -> Type[nn.Module]:
@@ -41,8 +49,14 @@ class BaseModel(pl.LightningModule):
        optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
        args = {} or self.lr_scheduler_args["args"]
-        scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(**args)
-        return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": self.monitor}
+        scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
+            **args
+        )
+        return {
+            "optimizer": optimizer,
+            "lr_scheduler": scheduler,
+            "monitor": self.monitor,
+        }
    def forward(self, data: Tensor) -> Tensor:
        """Feedforward pass."""
@@ -55,7 +69,7 @@ class BaseModel(pl.LightningModule):
        loss = self.loss_fn(logits, targets)
        self.log("train_loss", loss)
        self.train_acc(logits, targets)
-        self.log("train_acc": self.train_acc, on_step=False, on_epoch=True)
+        self.log("train_acc", self.train_acc, on_step=False, on_epoch=True)
        return loss
    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -73,5 +87,3 @@ class BaseModel(pl.LightningModule):
        logits = self(data)
        self.test_acc(logits, targets)
        self.log("test_acc", self.test_acc, on_step=False, on_epoch=True)
-
-
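
The refactored BaseModel resolves its optimizer and scheduler reflectively from config dicts via getattr, and hands Lightning the dict returned by configure_optimizers. A hedged sketch of what such dicts and the lookup might look like; the concrete types and hyperparameters are illustrative assumptions, not values from the repository config:

import torch

# Hypothetical config dicts: "type" names an attribute of torch.optim /
# torch.optim.lr_scheduler, "args" holds its keyword arguments.
optimizer_args = {"type": "AdamW", "args": {"lr": 3e-4, "weight_decay": 1e-2}}
lr_scheduler_args = {"type": "CosineAnnealingLR", "args": {"T_max": 100}}

# The same reflective lookup configure_optimizers performs; the model parameters
# are passed explicitly here since the surrounding lines are not shown in the diff.
model = torch.nn.Linear(8, 2)
optimizer = getattr(torch.optim, optimizer_args["type"])(
    model.parameters(), **optimizer_args["args"]
)
scheduler = getattr(torch.optim.lr_scheduler, lr_scheduler_args["type"])(
    optimizer, **lr_scheduler_args["args"]
)
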
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
new file mode 100644
index 0000000..58d0537
--- /dev/null
+++ b/text_recognizer/models/metrics.py
@@ -0,0 +1,32 @@
+"""Character Error Rate (CER)."""
+from typing import Sequence
+
+import editdistance
+import torch
+from torch import Tensor
+import torchmetrics
+
+
+class CharacterErrorRate(torchmetrics.Metric):
+ """Character error rate metric, computed using Levenshtein distance."""
+
+ def __init__(self, ignore_tokens: Sequence[int], *args) -> None:
+ super().__init__()
+ self.ignore_tokens = set(ignore_tokens)
+ self.add_state("error", default=torch.tensor(0.0), dist_reduce_fx="sum")
+ self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
+
+ def update(self, preds: Tensor, targets: Tensor) -> None:
+ """Update CER."""
+ bsz = preds.shape[0]
+ for index in range(bsz):
+ pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
+ target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+ distance = editdistance.distance(pred, target)
+ error = distance / max(len(pred), len(target))
+ self.error += error
+ self.total += bsz
+
+ def compute(self) -> Tensor:
+ """Compute CER."""
+ return self.error / self.total
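
CharacterErrorRate follows the torchmetrics state pattern: update() strips the ignore tokens, accumulates each sample's normalized edit distance into error and the batch size into total (both reduced with sum across processes), and compute() returns their ratio. A hedged usage sketch outside Lightning, with illustrative token tensors:

import torch
from text_recognizer.models.metrics import CharacterErrorRate

cer = CharacterErrorRate(ignore_tokens=[0, 1])
for preds, targets in [
    (torch.tensor([[0, 2, 2, 3, 3, 1]]), torch.tensor([[0, 2, 2, 3, 3, 1]])),
    (torch.tensor([[0, 2, 1, 1, 1, 1]]), torch.tensor([[0, 2, 2, 3, 3, 1]])),
]:
    cer.update(preds, targets)  # accumulates error and total across batches
print(cer.compute())            # (0.0 + 0.75) / 2 -> tensor(0.3750)
cer.reset()                     # clears the accumulated state
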
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d03f630..d67d297 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -16,7 +16,7 @@ class PositionalEncoding(nn.Module):
        self.dropout = nn.Dropout(p=dropout_rate)
        pe = self.make_pe(hidden_dim, max_len)
        self.register_buffer("pe", pe)
-
+
    @staticmethod
    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
        """Returns positional encoding."""
@@ -40,7 +40,7 @@ class PositionalEncoding(nn.Module):
class PositionalEncoding2D(nn.Module):
"""Positional encodings for feature maps."""
- def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int =2048) -> None:
+ def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
super().__init__()
if hidden_dim % 2 != 0:
raise ValueError(f"Embedding depth {hidden_dim} is not even!")
@@ -50,10 +50,14 @@ class PositionalEncoding2D(nn.Module):
    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
        """Returns 2d positional encoding."""
-        pe_h = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h)  # [H, 1, D // 2]
+        pe_h = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_h
+        )  # [H, 1, D // 2]
        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
-        pe_w = PositionalEncoding.make_pe(hidden_dim // 2, max_len=max_h)  # [W, 1, D // 2]
+        pe_w = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_h
+        )  # [W, 1, D // 2]
        pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h)
        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
@@ -64,7 +68,5 @@ class PositionalEncoding2D(nn.Module):
        # Assumes x has shape [B, D, H, W]
        if x.shape[1] != self.pe.shape[0]:
            raise ValueError("Hidden dimensions does not match.")
-        x += self.pe[:, :x.shape[2], :x.shape[3]]
+        x += self.pe[:, : x.shape[2], : x.shape[3]]
        return x
-
-
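
The reflowed make_pe calls above produce 1D sinusoidal encodings of shape [max_len, 1, hidden_dim], which the 2D variant tiles over height and width. The body of PositionalEncoding.make_pe is not part of this diff, so the following is only a sketch of the usual sinusoidal formulation it presumably implements:

import math

import torch


def make_pe(hidden_dim: int, max_len: int) -> torch.Tensor:
    """Sketch of a standard sinusoidal positional encoding, shape [max_len, 1, hidden_dim]."""
    pe = torch.zeros(max_len, hidden_dim)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)
    )
    pe[:, 0::2] = torch.sin(position * div_term)  # even channels
    pe[:, 1::2] = torch.cos(position * div_term)  # odd channels
    return pe.unsqueeze(1)  # [max_len, 1, hidden_dim]
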