diff options
Diffstat (limited to 'src/training')
-rw-r--r-- | src/training/run_experiment.py | 20 |
1 files changed, 4 insertions, 16 deletions
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 0510d5c..e6ae84c 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -9,7 +9,6 @@ import re from typing import Callable, Dict, List, Optional, Tuple, Type import warnings -import adabelief_pytorch import click from loguru import logger import numpy as np @@ -17,19 +16,17 @@ import torch from torchsummary import summary from tqdm import tqdm from training.gpu_manager import GPUManager -from training.trainer.callbacks import Callback, CallbackList +from training.trainer.callbacks import CallbackList from training.trainer.train import Trainer import wandb import yaml from text_recognizer.models import Model -from text_recognizer.networks import loss as custom_loss_module +from text_recognizer.networks.loss import loss as custom_loss_module EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" -DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16} - def _get_level(verbose: int) -> int: """Sets the logger level.""" @@ -107,11 +104,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic criterion_args = experiment_config["criterion"].get("args", {}) or {} # Optimizers - if experiment_config["optimizer"]["type"] == "AdaBelief": - warnings.filterwarnings("ignore", category=UserWarning) - optimizer_ = getattr(adabelief_pytorch, experiment_config["optimizer"]["type"]) - else: - optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) + optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) optimizer_args = experiment_config["optimizer"].get("args", {}) # Learning rate scheduler @@ -277,11 +270,6 @@ def run_experiment( # Lets W&B save the model and track the gradients and optional parameters. wandb.watch(model.network) - experiment_config["train_args"] = { - **DEFAULT_TRAIN_ARGS, - **experiment_config.get("train_args", {}), - } - experiment_config["experiment_group"] = experiment_config.get( "experiment_group", None ) @@ -351,7 +339,7 @@ def run_experiment( "--pretrained_weights", type=str, help="Path to pretrained model weights." ) @click.option( - "--notrain", is_flag=False, is_eager=True, help="Do not train the model.", + "--notrain", is_flag=False, help="Do not train the model.", ) def run_cli( experiment_config: str, |