From 31e9673eef3088f08e3ee6aef8b78abd701ca329 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 4 Apr 2021 16:05:13 +0200 Subject: Reformat test for CER --- tests/test_cer.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 tests/test_cer.py (limited to 'tests/test_cer.py') diff --git a/tests/test_cer.py b/tests/test_cer.py new file mode 100644 index 0000000..30d58b2 --- /dev/null +++ b/tests/test_cer.py @@ -0,0 +1,23 @@ +"""Test the CER metric.""" +import torch + +from text_recognizer.models.metrics import CharacterErrorRate + + +def test_character_error_rate() -> None: + """Test CER computation.""" + metric = CharacterErrorRate([0, 1]) + preds = torch.Tensor( + [ + [0, 2, 2, 3, 3, 1], # error will be 0 + [0, 2, 1, 1, 1, 1], # error will be 0.75 + [0, 2, 2, 4, 4, 1], # error will be 0.5 + ] + ) + + targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]]) + metric(preds, targets) + print(metric.compute()) + assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3) + + -- cgit v1.2.3-70-g09d2