summaryrefslogtreecommitdiff
path: root/src/training/run_experiment.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--src/training/run_experiment.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 55a9572..a883b45 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -21,8 +21,9 @@ from training.trainer.train import Trainer
import wandb
import yaml
-
+import text_recognizer.models
from text_recognizer.models import Model
+import text_recognizer.networks
from text_recognizer.networks.loss import loss as custom_loss_module
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
@@ -77,13 +78,12 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic
dataset_ = dataset_args["type"]
# Import the model module and model arguments.
- models_module = importlib.import_module("text_recognizer.models")
- model_class_ = getattr(models_module, experiment_config["model"])
+ model_class_ = getattr(text_recognizer.models, experiment_config["model"])
# Import metrics.
metric_fns_ = (
{
- metric: getattr(models_module, metric)
+ metric: getattr(text_recognizer.networks, metric)
for metric in experiment_config["metrics"]
}
if experiment_config["metrics"] is not None