From 25b5d6983d51e0e791b96a76beb7e49f392cd9a8 Mon Sep 17 00:00:00 2001 From: aktersnurra Date: Mon, 7 Dec 2020 22:54:04 +0100 Subject: Segmentation working! --- .../tests/support/iam_paragraphs/a01-000u.jpg | Bin 0 -> 14890 bytes .../tests/test_paragraph_text_recognizer.py | 37 +++++++++++++++++++++ 2 files changed, 37 insertions(+) create mode 100644 src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg create mode 100644 src/text_recognizer/tests/test_paragraph_text_recognizer.py (limited to 'src/text_recognizer/tests') diff --git a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg new file mode 100644 index 0000000..d9753b6 Binary files /dev/null and b/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg differ diff --git a/src/text_recognizer/tests/test_paragraph_text_recognizer.py b/src/text_recognizer/tests/test_paragraph_text_recognizer.py new file mode 100644 index 0000000..3e280b9 --- /dev/null +++ b/src/text_recognizer/tests/test_paragraph_text_recognizer.py @@ -0,0 +1,37 @@ +"""Test for ParagraphTextRecognizer class.""" +import os +from pathlib import Path +import unittest + +from text_recognizer.paragraph_text_recognizer import ParagraphTextRecognizor +import text_recognizer.util as util + + +SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support" / "iam_paragraph" + +# Prevent using GPU. +os.environ["CUDA_VISIBLE_DEVICES"] = "" + + +class TestParagraphTextRecognizor(unittest.TestCase): + """Test that it can take non-square images of max dimension larger than 256px.""" + + def test_filename(self) -> None: + """Test model on support image.""" + line_predictor_args = { + "dataset": "EmnistLineDataset", + "network_fn": "CNNTransformer", + } + line_detector_args = {"dataset": "EmnistLineDataset", "network_fn": "UNet"} + model = ParagraphTextRecognizor( + line_predictor_args=line_predictor_args, + line_detector_args=line_detector_args, + ) + num_text_lines_by_name = {"a01-000u-cropped": 7} + for filename in (SUPPORT_DIRNAME).glob("*.jpg"): + full_image = util.read_image(str(filename), grayscale=True) + predicted_text, line_region_crops = model.predict(full_image) + print(predicted_text) + self.assertTrue( + len(line_region_crops), num_text_lines_by_name[filename.stem] + ) -- cgit v1.2.3-70-g09d2