summaryrefslogtreecommitdiff
path: root/src/training
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-11-18 20:56:19 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2020-11-18 20:56:19 +0100
commit527bb98b191d82b308de1585047e06056258d08d (patch)
treef33145dba398825871da3184a2735f6fb0b07268 /src/training
parentf2cd16f340aa11afadb8fa90c29f85ca1b75a600 (diff)
Some minor changes.
Diffstat (limited to 'src/training')
-rw-r--r--src/training/run_experiment.py8
-rw-r--r--src/training/trainer/train.py4
2 files changed, 8 insertions, 4 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 55a9572..a883b45 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -21,8 +21,9 @@ from training.trainer.train import Trainer
import wandb
import yaml
-
+import text_recognizer.models
from text_recognizer.models import Model
+import text_recognizer.networks
from text_recognizer.networks.loss import loss as custom_loss_module
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
@@ -77,13 +78,12 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic
dataset_ = dataset_args["type"]
# Import the model module and model arguments.
- models_module = importlib.import_module("text_recognizer.models")
- model_class_ = getattr(models_module, experiment_config["model"])
+ model_class_ = getattr(text_recognizer.models, experiment_config["model"])
# Import metrics.
metric_fns_ = (
{
- metric: getattr(models_module, metric)
+ metric: getattr(text_recognizer.networks, metric)
for metric in experiment_config["metrics"]
}
if experiment_config["metrics"] is not None
diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py
index 223d9c6..8ae994a 100644
--- a/src/training/trainer/train.py
+++ b/src/training/trainer/train.py
@@ -3,6 +3,7 @@
from pathlib import Path
import time
from typing import Dict, List, Optional, Tuple, Type
+import warnings
from einops import rearrange
from loguru import logger
@@ -23,6 +24,9 @@ torch.manual_seed(4711)
torch.cuda.manual_seed(4711)
+warnings.filterwarnings("ignore")
+
+
class Trainer:
"""Trainer for training PyTorch models."""