diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-02-24 22:00:29 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-02-24 22:00:29 +0100 |
commit | 905eeeb4c3c0ba54b5414eb8f435e2e9870b7307 (patch) | |
tree | 91dab598a94911e6147b996237e786dd47f11f2f /src/text_recognizer/networks/transducer/test.py | |
parent | 4a54d7e690897dd6e6c719fb908fd371a44c2952 (diff) |
updates
Diffstat (limited to 'src/text_recognizer/networks/transducer/test.py')
-rw-r--r-- | src/text_recognizer/networks/transducer/test.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py new file mode 100644 index 0000000..cadcecc --- /dev/null +++ b/src/text_recognizer/networks/transducer/test.py @@ -0,0 +1,60 @@ +import torch +from torch import nn + +from text_recognizer.networks.transducer import load_transducer_loss, Transducer +import unittest + + +class TestTransducer(unittest.TestCase): + def test_viterbi(self): + T = 5 + N = 4 + B = 2 + + # fmt: off + emissions1 = torch.tensor(( + 0, 4, 0, 1, + 0, 2, 1, 1, + 0, 0, 0, 2, + 0, 0, 0, 2, + 8, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + emissions2 = torch.tensor(( + 0, 2, 1, 7, + 0, 2, 9, 1, + 0, 0, 0, 2, + 0, 0, 5, 2, + 1, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + # fmt: on + + # Test without blank: + labels = [[1, 3, 0], [3, 2, 3, 2, 3]] + transducer = Transducer( + tokens=["a", "b", "c", "d"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, + blank="none", + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + # Test with blank without repeats: + labels = [[1, 0], [2, 2]] + transducer = Transducer( + tokens=["a", "b", "c"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2}, + blank="optional", + allow_repeats=False, + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + +if __name__ == "__main__": + unittest.main() |