summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/transducer/test.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/transducer/test.py')
-rw-r--r--src/text_recognizer/networks/transducer/test.py60
1 files changed, 60 insertions, 0 deletions
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py
new file mode 100644
index 0000000..cadcecc
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/test.py
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from text_recognizer.networks.transducer import load_transducer_loss, Transducer
+import unittest
+
+
+class TestTransducer(unittest.TestCase):
+ def test_viterbi(self):
+ T = 5
+ N = 4
+ B = 2
+
+ # fmt: off
+ emissions1 = torch.tensor((
+ 0, 4, 0, 1,
+ 0, 2, 1, 1,
+ 0, 0, 0, 2,
+ 0, 0, 0, 2,
+ 8, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ emissions2 = torch.tensor((
+ 0, 2, 1, 7,
+ 0, 2, 9, 1,
+ 0, 0, 0, 2,
+ 0, 0, 5, 2,
+ 1, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ # fmt: on
+
+ # Test without blank:
+ labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
+ transducer = Transducer(
+ tokens=["a", "b", "c", "d"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
+ blank="none",
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+ # Test with blank without repeats:
+ labels = [[1, 0], [2, 2]]
+ transducer = Transducer(
+ tokens=["a", "b", "c"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2},
+ blank="optional",
+ allow_repeats=False,
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+
+if __name__ == "__main__":
+ unittest.main()