diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-12-07 22:54:04 +0100 |
commit | 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 (patch) | |
tree | 526ba739714b3d040f7810c1a6be3ff0ba37fdb1 /src/text_recognizer/tests | |
parent | 5529e0fc9ca39e81fe0f08a54f257d32f0afe120 (diff) |
Segmentation working!
Diffstat (limited to 'src/text_recognizer/tests')
-rw-r--r-- | src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg | bin | 0 -> 14890 bytes | |||
-rw-r--r-- | src/text_recognizer/tests/test_paragraph_text_recognizer.py | 37 |
2 files changed, 37 insertions, 0 deletions
diff --git a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg Binary files differnew file mode 100644 index 0000000..d9753b6 --- /dev/null +++ b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg diff --git a/src/text_recognizer/tests/test_paragraph_text_recognizer.py b/src/text_recognizer/tests/test_paragraph_text_recognizer.py new file mode 100644 index 0000000..3e280b9 --- /dev/null +++ b/src/text_recognizer/tests/test_paragraph_text_recognizer.py @@ -0,0 +1,37 @@ +"""Test for ParagraphTextRecognizer class.""" +import os +from pathlib import Path +import unittest + +from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraph" + +# Prevent using GPU. +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestParagraphTextRecognizor(unittest.TestCase): + """Test that it can take non-square images of max dimension larger than 256px.""" + + def test_filename(self) -> None: + """Test model on support image.""" + line_predictor_args = { + "dataset": "EmnistLineDataset", + "network_fn": "CNNTransformer", + } + line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"} + model = ParagraphTextRecognizor( + line_predictor_args=line_predictor_args, + line_detector_args=line_detector_args, + ) + num_text_lines_by_name = {"a01-000u-cropped": 7} + for filename in (SUPPORT_DIRNAME).glob("*.jpg"): + full_image = util.read_image(str(filename), grayscale=True) + predicted_text, line_region_crops = model.predict(full_image) + print(predicted_text) + self.assertTrue( + len(line_region_crops), num_text_lines_by_name[filename.stem] + ) |