diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-06 23:19:35 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-06 23:19:35 +0200 |
commit | 01d6e5fc066969283df99c759609df441151e9c5 (patch) | |
tree | ecd1459e142356d0c7f50a61307b760aca813248 /text_recognizer/networks/transducer/test.py | |
parent | f4688482b4898c0b342d6ae59839dc27fbf856c6 (diff) |
Working on fixing decoder transformer
Diffstat (limited to 'text_recognizer/networks/transducer/test.py')
-rw-r--r-- | text_recognizer/networks/transducer/test.py | 60 |
1 files changed, 0 insertions, 60 deletions
diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py deleted file mode 100644 index cadcecc..0000000 --- a/text_recognizer/networks/transducer/test.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import nn - -from text_recognizer.networks.transducer import load_transducer_loss, Transducer -import unittest - - -class TestTransducer(unittest.TestCase): - def test_viterbi(self): - T = 5 - N = 4 - B = 2 - - # fmt: off - emissions1 = torch.tensor(( - 0, 4, 0, 1, - 0, 2, 1, 1, - 0, 0, 0, 2, - 0, 0, 0, 2, - 8, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - emissions2 = torch.tensor(( - 0, 2, 1, 7, - 0, 2, 9, 1, - 0, 0, 0, 2, - 0, 0, 5, 2, - 1, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - # fmt: on - - # Test without blank: - labels = [[1, 3, 0], [3, 2, 3, 2, 3]] - transducer = Transducer( - tokens=["a", "b", "c", "d"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, - blank="none", - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - # Test with blank without repeats: - labels = [[1, 0], [2, 2]] - transducer = Transducer( - tokens=["a", "b", "c"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2}, - blank="optional", - allow_repeats=False, - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - -if __name__ == "__main__": - unittest.main() |