summaryrefslogtreecommitdiff
path: root/tests/test_cer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 16:05:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 16:05:13 +0200
commit31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch)
treef529d975d18d718a5d646e93f746d8be6f2f5cfe /tests/test_cer.py
parent36964354407d0fdf73bdca2f611fee1664860197 (diff)
Reformat test for CER
Diffstat (limited to 'tests/test_cer.py')
-rw-r--r--tests/test_cer.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/test_cer.py b/tests/test_cer.py
new file mode 100644
index 0000000..30d58b2
--- /dev/null
+++ b/tests/test_cer.py
@@ -0,0 +1,23 @@
+"""Test the CER metric."""
+import torch
+
+from text_recognizer.models.metrics import CharacterErrorRate
+
+
+def test_character_error_rate() -> None:
+ """Test CER computation."""
+ metric = CharacterErrorRate([0, 1])
+ preds = torch.Tensor(
+ [
+ [0, 2, 2, 3, 3, 1], # error will be 0
+ [0, 2, 1, 1, 1, 1], # error will be 0.75
+ [0, 2, 2, 4, 4, 1], # error will be 0.5
+ ]
+ )
+
+ targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]])
+ metric(preds, targets)
+ print(metric.compute())
+ assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3)
+
+