diff options
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_cer.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/test_cer.py b/tests/test_cer.py new file mode 100644 index 0000000..30d58b2 --- /dev/null +++ b/tests/test_cer.py @@ -0,0 +1,23 @@ +"""Test the CER metric.""" +import torch + +from text_recognizer.models.metrics import CharacterErrorRate + + +def test_character_error_rate() -> None: + """Test CER computation.""" + metric = CharacterErrorRate([0, 1]) + preds = torch.Tensor( + [ + [0, 2, 2, 3, 3, 1], # error will be 0 + [0, 2, 1, 1, 1, 1], # error will be 0.75 + [0, 2, 2, 4, 4, 1], # error will be 0.5 + ] + ) + + targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]]) + metric(preds, targets) + print(metric.compute()) + assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3) + + |