diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /text_recognizer/networks/transducer/test.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'text_recognizer/networks/transducer/test.py')
-rw-r--r-- | text_recognizer/networks/transducer/test.py | 60 |
1 files changed, 60 insertions, 0 deletions
diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py new file mode 100644 index 0000000..cadcecc --- /dev/null +++ b/text_recognizer/networks/transducer/test.py @@ -0,0 +1,60 @@ +import torch +from torch import nn + +from text_recognizer.networks.transducer import load_transducer_loss, Transducer +import unittest + + +class TestTransducer(unittest.TestCase): + def test_viterbi(self): + T = 5 + N = 4 + B = 2 + + # fmt: off + emissions1 = torch.tensor(( + 0, 4, 0, 1, + 0, 2, 1, 1, + 0, 0, 0, 2, + 0, 0, 0, 2, + 8, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + emissions2 = torch.tensor(( + 0, 2, 1, 7, + 0, 2, 9, 1, + 0, 0, 0, 2, + 0, 0, 5, 2, + 1, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + # fmt: on + + # Test without blank: + labels = [[1, 3, 0], [3, 2, 3, 2, 3]] + transducer = Transducer( + tokens=["a", "b", "c", "d"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, + blank="none", + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + # Test with blank without repeats: + labels = [[1, 0], [2, 2]] + transducer = Transducer( + tokens=["a", "b", "c"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2}, + blank="optional", + allow_repeats=False, + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + +if __name__ == "__main__": + unittest.main() |