diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 16:05:13 +0200 |
commit | 31e9673eef3088f08e3ee6aef8b78abd701ca329 (patch) | |
tree | f529d975d18d718a5d646e93f746d8be6f2f5cfe /tests | |
parent | 36964354407d0fdf73bdca2f611fee1664860197 (diff) |
Reformat test for CER
Diffstat (limited to 'tests')
-rw-r--r-- | tests/test_cer.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/tests/test_cer.py b/tests/test_cer.py new file mode 100644 index 0000000..30d58b2 --- /dev/null +++ b/tests/test_cer.py @@ -0,0 +1,23 @@ +"""Test the CER metric.""" +import torch + +from text_recognizer.models.metrics import CharacterErrorRate + + +def test_character_error_rate() -> None: + """Test CER computation.""" + metric = CharacterErrorRate([0, 1]) + preds = torch.Tensor( + [ + [0, 2, 2, 3, 3, 1], # error will be 0 + [0, 2, 1, 1, 1, 1], # error will be 0.75 + [0, 2, 2, 4, 4, 1], # error will be 0.5 + ] + ) + + targets = torch.Tensor([[0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1], [0, 2, 2, 3, 3, 1]]) + metric(preds, targets) + print(metric.compute()) + assert metric.compute() == float(sum([0, 0.75, 0.5]) / 3) + + |