summaryrefslogtreecommitdiff
path: root/src/text_recognizer/tests
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/tests')
-rw-r--r--src/text_recognizer/tests/test_character_predictor.py14
1 files changed, 2 insertions, 12 deletions
diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py
index c603a3a..01bda78 100644
--- a/src/text_recognizer/tests/test_character_predictor.py
+++ b/src/text_recognizer/tests/test_character_predictor.py
@@ -4,7 +4,6 @@ import os
from pathlib import Path
import unittest
-import click
from loguru import logger
from text_recognizer.character_predictor import CharacterPredictor
@@ -18,19 +17,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = ""
class TestCharacterPredictor(unittest.TestCase):
"""Tests for the CharacterPredictor class."""
- # @click.command()
- # @click.option(
- # "--network", type=str, help="Network to load, e.g. MLP or LeNet.", default="MLP"
- # )
def test_filename(self) -> None:
"""Test that CharacterPredictor correctly predicts on a single image, for serveral test images."""
- network_module = importlib.import_module("text_recognizer.networks")
- network_fn_ = getattr(network_module, "MLP")
- # network_args = {"input_size": [28, 28], "output_size": 62, "dropout_rate": 0}
- network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2}
- predictor = CharacterPredictor(
- network_fn=network_fn_, network_args=network_args
- )
+ network_fn_ = MLP
+ predictor = CharacterPredictor(network_fn=network_fn_)
for filename in SUPPORT_DIRNAME.glob("*.png"):
pred, conf = predictor.predict(str(filename))