author     aktersnurra <gustaf.rydholm@gmail.com>    2020-11-12 23:42:03 +0100
committer  aktersnurra <gustaf.rydholm@gmail.com>    2020-11-12 23:42:03 +0100
commit     8fdb6435e15703fa5b76df19728d905650ee1aef (patch)
tree       be3bec9e5cab4ef7f9d94528d102e57ce9b16c3f /src/training
parent     dc28cbe2b4ed77be92ee8b2b69a20689c3bf02a4 (diff)
parent     6cb08a110620ee09fe9d8a5d008197a801d025df (diff)
Working cnn transformer.
Diffstat (limited to 'src/training')
-rw-r--r--  src/training/run_experiment.py  20
1 file changed, 4 insertions(+), 16 deletions(-)
diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py
index 0510d5c..e6ae84c 100644
--- a/src/training/run_experiment.py
+++ b/src/training/run_experiment.py
@@ -9,7 +9,6 @@ import re
from typing import Callable, Dict, List, Optional, Tuple, Type
import warnings
-import adabelief_pytorch
import click
from loguru import logger
import numpy as np
@@ -17,19 +16,17 @@ import torch
from torchsummary import summary
from tqdm import tqdm
from training.gpu_manager import GPUManager
-from training.trainer.callbacks import Callback, CallbackList
+from training.trainer.callbacks import CallbackList
from training.trainer.train import Trainer
import wandb
import yaml
from text_recognizer.models import Model
-from text_recognizer.networks import loss as custom_loss_module
+from text_recognizer.networks.loss import loss as custom_loss_module
EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
-DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}
-
def _get_level(verbose: int) -> int:
"""Sets the logger level."""
@@ -107,11 +104,7 @@ def _load_modules_and_arguments(experiment_config: Dict,) -> Tuple[Callable, Dic
criterion_args = experiment_config["criterion"].get("args", {}) or {}
# Optimizers
- if experiment_config["optimizer"]["type"] == "AdaBelief":
- warnings.filterwarnings("ignore", category=UserWarning)
- optimizer_ = getattr(adabelief_pytorch, experiment_config["optimizer"]["type"])
- else:
- optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
+ optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"])
optimizer_args = experiment_config["optimizer"].get("args", {})
# Learning rate scheduler
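
With the AdaBelief special case removed above, the optimizer is resolved purely by name from torch.optim. A minimal sketch of that lookup, using the config keys visible in the hunk; the helper name and the example config values are illustrative, not taken from the repository:

from typing import Dict, Iterable

import torch


def _resolve_optimizer(experiment_config: Dict, parameters: Iterable) -> torch.optim.Optimizer:
    """Instantiate the optimizer named in the experiment config from torch.optim."""
    optimizer_class = getattr(torch.optim, experiment_config["optimizer"]["type"])
    optimizer_args = experiment_config["optimizer"].get("args", {})
    return optimizer_class(parameters, **optimizer_args)

# Example experiment config fragment (illustrative):
# optimizer:
#   type: AdamW
#   args:
#     lr: 3.0e-4

One consequence of the change is that any optimizer named in a config must now exist in torch.optim; a config that still requests AdaBelief will raise an AttributeError at lookup time.
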
@@ -277,11 +270,6 @@ def run_experiment(
# Lets W&B save the model and track the gradients and optional parameters.
wandb.watch(model.network)
- experiment_config["train_args"] = {
- **DEFAULT_TRAIN_ARGS,
- **experiment_config.get("train_args", {}),
- }
-
experiment_config["experiment_group"] = experiment_config.get(
"experiment_group", None
)
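
For reference, the lines removed above merged module-level defaults into the config so that batch_size and epochs were always set. The merge pattern, taken from the deleted code with a comment on its semantics, looked like the sketch below; after this commit those values presumably come from the experiment config alone:

DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16}

# Keys present in the experiment config win; missing keys fall back to the defaults.
experiment_config["train_args"] = {
    **DEFAULT_TRAIN_ARGS,
    **experiment_config.get("train_args", {}),
}
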
@@ -351,7 +339,7 @@ def run_experiment(
"--pretrained_weights", type=str, help="Path to pretrained model weights."
)
@click.option(
- "--notrain", is_flag=False, is_eager=True, help="Do not train the model.",
+ "--notrain", is_flag=False, help="Do not train the model.",
)
def run_cli(
experiment_config: str,