summaryrefslogtreecommitdiff
path: root/src/text_recognizer/tests
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/tests')
-rw-r--r--src/text_recognizer/tests/support/create_emnist_support_files.py13
-rw-r--r--src/text_recognizer/tests/test_line_predictor.py35
2 files changed, 41 insertions, 7 deletions
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/src/text_recognizer/tests/support/create_emnist_support_files.py
index 5dd1a81..c04860d 100644
--- a/src/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/src/text_recognizer/tests/support/create_emnist_support_files.py
@@ -2,10 +2,8 @@
from pathlib import Path
import shutil
-from text_recognizer.datasets.emnist_dataset import (
- fetch_emnist_dataset,
- load_emnist_mapping,
-)
+from text_recognizer.datasets.emnist_dataset import EmnistDataset
+from text_recognizer.datasets.util import EmnistMapper
from text_recognizer.util import write_image
SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "emnist"
@@ -16,15 +14,16 @@ def create_emnist_support_files() -> None:
shutil.rmtree(SUPPORT_DIRNAME, ignore_errors=True)
SUPPORT_DIRNAME.mkdir()
- dataset = fetch_emnist_dataset(split="byclass", train=False)
- mapping = load_emnist_mapping()
+ dataset = EmnistDataset(train=False)
+ dataset.load_or_generate_data()
+ mapping = EmnistMapper()
for index in [5, 7, 9]:
image, label = dataset[index]
if len(image.shape) == 3:
image = image.squeeze(0)
image = image.numpy()
- label = mapping[int(label)]
+ label = mapping(int(label))
print(index, label)
write_image(image, str(SUPPORT_DIRNAME / f"{label}.png"))
diff --git a/src/text_recognizer/tests/test_line_predictor.py b/src/text_recognizer/tests/test_line_predictor.py
new file mode 100644
index 0000000..eede4d4
--- /dev/null
+++ b/src/text_recognizer/tests/test_line_predictor.py
@@ -0,0 +1,35 @@
+"""Tests for LinePredictor."""
+import os
+from pathlib import Path
+import unittest
+
+
+import editdistance
+import numpy as np
+
+from text_recognizer.datasets import IamLinesDataset
+from text_recognizer.line_predictor import LinePredictor
+import text_recognizer.util as util
+
+SUPPORT_DIRNAME = Path(__file__).parents[0].resolve() / "support"
+
+os.environ["CUDA_VISIBLE_DEVICES"] = ""
+
+
+class TestEmnistLinePredictor(unittest.TestCase):
+ """Test LinePredictor class on the EmnistLines dataset."""
+
+ def test_filename(self) -> None:
+ """Test that LinePredictor correctly predicts on single images, for several test images."""
+ predictor = LinePredictor(
+ dataset="EmnistLineDataset", network_fn="CNNTransformer"
+ )
+
+ for filename in (SUPPORT_DIRNAME / "emnist_lines").glob("*.png"):
+ pred, conf = predictor.predict(str(filename))
+ true = str(filename.stem)
+ edit_distance = editdistance.eval(pred, true) / len(pred)
+ print(
+ f'Pred: "{pred}" | Confidence: {conf} | True: {true} | Edit distance: {edit_distance}'
+ )
+ self.assertLess(edit_distance, 0.2)