summaryrefslogtreecommitdiff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/test_cer.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/test_cer.py b/tests/test_cer.py
new file mode 100644
index 0000000..30d58b2
--- /dev/null
+++ b/tests/test_cer.py
@@ -0,0 +1,23 @@
+"""Test the CER metric."""
+import torch
+
+from text_recognizer.models.metrics import CharacterErrorRate
+
+
+def test_character_error_rate() -> None:
+ """Test CER computation."""
+ metric = CharacterErrorRate([0, 1])
+ preds = torch.Tensor(
+ [
+ [0, 2, 2, 3, 3, 1], # error will be 0
+ [0, 2, 1, 1, 1, 1], # error will be 0.75
+ [0, 2, 2, 4, 4, 1], # error will be 0.5
+ ]
+ )
+
+ targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]])
+ metric(preds, targets)
+ print(metric.compute())
+ assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3)
+
+