diff options
Diffstat (limited to 'src/text_recognizer/tests/test_character_predictor.py')
-rw-r--r-- | src/text_recognizer/tests/test_character_predictor.py | 14 |
1 files changed, 2 insertions, 12 deletions
diff --git a/src/text_recognizer/tests/test_character_predictor.py b/src/text_recognizer/tests/test_character_predictor.py index c603a3a..01bda78 100644 --- a/src/text_recognizer/tests/test_character_predictor.py +++ b/src/text_recognizer/tests/test_character_predictor.py @@ -4,7 +4,6 @@ import os from pathlib import Path import unittest -import click from loguru import logger from text_recognizer.character_predictor import CharacterPredictor @@ -18,19 +17,10 @@ os.environ["CUDA_VISIBLE_DEVICES"] = "" class TestCharacterPredictor(unittest.TestCase): """Tests for the CharacterPredictor class.""" - # @click.command() - # @click.option( - # "--network", type=str, help="Network to load, e.g. MLP or LeNet.", default="MLP" - # ) def test_filename(self) -> None: """Test that CharacterPredictor correctly predicts on a single image, for serveral test images.""" - network_module = importlib.import_module("text_recognizer.networks") - network_fn_ = getattr(network_module, "MLP") - # network_args = {"input_size": [28, 28], "output_size": 62, "dropout_rate": 0} - network_args = {"input_size": 784, "output_size": 62, "dropout_rate": 0.2} - predictor = CharacterPredictor( - network_fn=network_fn_, network_args=network_args - ) + network_fn_ = MLP + predictor = CharacterPredictor(network_fn=network_fn_) for filename in SUPPORT_DIRNAME.glob("*.png"): pred, conf = predictor.predict(str(filename)) |