Diffstat (limited to 'src/training/run_experiment.py')
-rw-r--r--  src/training/run_experiment.py  20
1 file changed, 4 insertions(+), 16 deletions(-)
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 0510d5c..e6ae84c 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -9,7 +9,6 @@ import re
from typing import Callable, Dict, List, Optional, Tuple, Type
import warnings
-import adabelief_pytorch
import click
from loguru import logger
import numpy as np
@@ -17,19 +16,17 @@ import torch
from torchsummary import summary
from tqdm import tqdm
from training.gpu_manager import GPUManager
-from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.callbacks import CallbackList
from training.trainer.train import Trainer
import wandb
import yaml
from text_recognizer.models import Model
-from text_recognizer.networks import loss as custom_loss_module
+from text_recognizer.networks.loss import loss as custom_loss_module
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}
-
def _get_level(verbose: int) -> int:
"""Sets the logger level."""
@@ -107,11 +104,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic
    criterion_args = experiment_config["criterion"].get("args", {}) or {}
    # Optimizers
-    if experiment_config["optimizer"]["type"] == "AdaBelief":
-        warnings.filterwarnings("ignore", category=UserWarning)
-        optimizer_ = getattr(adabelief_pytorch, experiment_config["optimizer"]["type"])
-    else:
-        optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+    optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
    optimizer_args = experiment_config["optimizer"].get("args", {})
    # Learning rate scheduler
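With the AdaBelief branch removed, the optimizer class is now resolved directly from torch.optim by the "type" name in the experiment config. A minimal sketch of that lookup, assuming a config shaped like the accesses above (the "Adam" name and lr value are invented examples, not taken from this repository's configs):

# Sketch only: mirrors the getattr lookup in the new code; the concrete
# optimizer name and hyperparameters below are illustrative.
import torch

experiment_config = {"optimizer": {"type": "Adam", "args": {"lr": 3e-4}}}

optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
optimizer_args = experiment_config["optimizer"].get("args", {})
optimizer = optimizer_(torch.nn.Linear(8, 2).parameters(), **optimizer_args)

Any optimizer named in a config must therefore exist in torch.optim; a config that still says "AdaBelief" would now fail with an AttributeError at this lookup.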
@@ -277,11 +270,6 @@ def run_experiment(
    # Lets W&B save the model and track the gradients and optional parameters.
    wandb.watch(model.network)
-    experiment_config["train_args"] = {
-        **DEFAULT_TRAIN_ARGS,
-        **experiment_config.get("train_args", {}),
-    }
-
    experiment_config["experiment_group"] = experiment_config.get(
        "experiment_group", None
    )
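For context on the deleted block: it overlaid the config's train_args on top of DEFAULT_TRAIN_ARGS, so partially specified configs were filled in with a batch size of 64 and 16 epochs. A small sketch of that old behaviour (the epochs override is an invented example):

# Sketch of the removed merge; the epochs=32 override is illustrative.
DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}
experiment_config = {"train_args": {"epochs": 32}}
experiment_config["train_args"] = {
    **DEFAULT_TRAIN_ARGS,
    **experiment_config.get("train_args", {}),
}
# -> {"batch_size": 64, "epochs": 32}

After this commit no defaults are injected, so an experiment config presumably has to carry its train_args in full.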
@@ -351,7 +339,7 @@ def run_experiment(
"--pretrained_weights", type=str, help="Path to pretrained model weights."
)
@click.option(
- "--notrain", is_flag=False, is_eager=True, help="Do not train the model.",
+ "--notrain", is_flag=False, help="Do not train the model.",
)
def run_cli(
    experiment_config: str,