summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/transducer/test.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/networks/transducer/test.py
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/networks/transducer/test.py')
-rw-r--r--src/text_recognizer/networks/transducer/test.py60
1 files changed, 0 insertions, 60 deletions
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py
deleted file mode 100644
index cadcecc..0000000
--- a/src/text_recognizer/networks/transducer/test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch
-from torch import nn
-
-from text_recognizer.networks.transducer import load_transducer_loss, Transducer
-import unittest
-
-
-class TestTransducer(unittest.TestCase):
- def test_viterbi(self):
- T = 5
- N = 4
- B = 2
-
- # fmt: off
- emissions1 = torch.tensor((
- 0, 4, 0, 1,
- 0, 2, 1, 1,
- 0, 0, 0, 2,
- 0, 0, 0, 2,
- 8, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- emissions2 = torch.tensor((
- 0, 2, 1, 7,
- 0, 2, 9, 1,
- 0, 0, 0, 2,
- 0, 0, 5, 2,
- 1, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- # fmt: on
-
- # Test without blank:
- labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
- transducer = Transducer(
- tokens=["a", "b", "c", "d"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
- blank="none",
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
- # Test with blank without repeats:
- labels = [[1, 0], [2, 2]]
- transducer = Transducer(
- tokens=["a", "b", "c"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2},
- blank="optional",
- allow_repeats=False,
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
-
-if __name__ == "__main__":
- unittest.main()