summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--training/configs/image_transformer.yaml (renamed from text_recognizer/training/experiments/image_transformer.yaml)0
-rw-r--r--training/run_experiment.py (renamed from text_recognizer/training/run_experiment.py)20
2 files changed, 11 insertions, 9 deletions
diff --git a/text_recognizer/training/experiments/image_transformer.yaml b/training/configs/image_transformer.yaml
index bedcbb5..bedcbb5 100644
--- a/text_recognizer/training/experiments/image_transformer.yaml
+++ b/training/configs/image_transformer.yaml
diff --git a/text_recognizer/training/run_experiment.py b/training/run_experiment.py
index ed1a947..f46803f 100644
--- a/text_recognizer/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -1,5 +1,4 @@
"""Script to run experiments."""
-from datetime import datetime
import importlib
from pathlib import Path
from typing import Dict, List, Optional, Type
@@ -10,13 +9,13 @@ from omegaconf import DictConfig, OmegaConf
import pytorch_lightning as pl
import torch
from torch import nn
-from torchsummary import summary
from tqdm import tqdm
import wandb
SEED = 4711
-EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments"
+CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs"
+LOGS_DIRNAME = Path(__file__).parent.resolve() / "runs" / "logs"
def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None:
@@ -68,15 +67,15 @@ def _configure_callbacks(
def _configure_logger(
- network: Type[nn.Module], args: Dict, use_wandb: bool
+ network: Type[nn.Module], args: Dict, log_dir: str, use_wandb: bool
) -> Type[pl.loggers.LightningLoggerBase]:
"""Configures lightning logger."""
if use_wandb:
- pl_logger = pl.loggers.WandbLogger()
+ pl_logger = pl.loggers.WandbLogger(save_dir=log_dir)
pl_logger.watch(network)
pl_logger.log_hyperparams(vars(args))
return pl_logger
- return pl.logger.TensorBoardLogger("training/logs")
+ return pl.logger.TensorBoardLogger(save_dir=log_dir)
def _save_best_weights(
@@ -119,16 +118,19 @@ def run(
verbose: int = 0,
) -> None:
"""Runs experiment."""
+ # Set log dir where logging output and weights are saved to.
+ log_dir = str(LOGS_DIRNAME)
_configure_logging(None, verbose=verbose)
logger.info("Starting experiment...")
+
# Seed everything in the experiment.
logger.info(f"Seeding everthing with seed={SEED}")
pl.utilities.seed.seed_everything(SEED)
# Load config.
- file_path = EXPERIMENTS_DIRNAME / filename
+ file_path = CONFIGS_DIRNAME / filename
config = _load_config(file_path)
# Load classes.
@@ -142,7 +144,7 @@ def run(
# Load callback and logger.
callbacks = _configure_callbacks(config.callbacks)
- pl_logger = _configure_logger(network, config, use_wandb)
+ pl_logger = _configure_logger(network, config, log_dir, use_wandb)
# Load ligtning model.
lit_model = _load_lit_model(lit_model_class, network, config)
@@ -151,7 +153,7 @@ def run(
**config.trainer.args,
callbacks=callbacks,
logger=pl_logger,
- weigths_save_path="training/logs",
+ weigths_save_path=log_dir,
)
if tune: