diff options
-rw-r--r-- | training/configs/image_transformer.yaml (renamed from text_recognizer/training/experiments/image_transformer.yaml) | 0 | ||||
-rw-r--r-- | training/run_experiment.py (renamed from text_recognizer/training/run_experiment.py) | 20 |
2 files changed, 11 insertions, 9 deletions
diff --git a/text_recognizer/training/experiments/image_transformer.yaml b/training/configs/image_transformer.yaml index bedcbb5..bedcbb5 100644 --- a/text_recognizer/training/experiments/image_transformer.yaml +++ b/training/configs/image_transformer.yaml diff --git a/text_recognizer/training/run_experiment.py b/training/run_experiment.py index ed1a947..f46803f 100644 --- a/text_recognizer/training/run_experiment.py +++ b/training/run_experiment.py @@ -1,5 +1,4 @@ """Script to run experiments.""" -from datetime import datetime import importlib from pathlib import Path from typing import Dict, List, Optional, Type @@ -10,13 +9,13 @@ from omegaconf import DictConfig, OmegaConf import pytorch_lightning as pl import torch from torch import nn -from torchsummary import summary from tqdm import tqdm import wandb SEED = 4711 -EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" +CONFIGS_DIRNAME = Path(__file__).parent.resolve() / "configs" +LOGS_DIRNAME = Path(__file__).parent.resolve() / "runs" / "logs" def _configure_logging(log_dir: Optional[Path], verbose: int = 0) -> None: @@ -68,15 +67,15 @@ def _configure_callbacks( def _configure_logger( - network: Type[nn.Module], args: Dict, use_wandb: bool + network: Type[nn.Module], args: Dict, log_dir: str, use_wandb: bool ) -> Type[pl.loggers.LightningLoggerBase]: """Configures lightning logger.""" if use_wandb: - pl_logger = pl.loggers.WandbLogger() + pl_logger = pl.loggers.WandbLogger(save_dir=log_dir) pl_logger.watch(network) pl_logger.log_hyperparams(vars(args)) return pl_logger - return pl.logger.TensorBoardLogger("training/logs") + return pl.logger.TensorBoardLogger(save_dir=log_dir) def _save_best_weights( @@ -119,16 +118,19 @@ def run( verbose: int = 0, ) -> None: """Runs experiment.""" + # Set log dir where logging output and weights are saved to. + log_dir = str(LOGS_DIRNAME) _configure_logging(None, verbose=verbose) logger.info("Starting experiment...") + # Seed everything in the experiment. logger.info(f"Seeding everthing with seed={SEED}") pl.utilities.seed.seed_everything(SEED) # Load config. - file_path = EXPERIMENTS_DIRNAME / filename + file_path = CONFIGS_DIRNAME / filename config = _load_config(file_path) # Load classes. @@ -142,7 +144,7 @@ def run( # Load callback and logger. callbacks = _configure_callbacks(config.callbacks) - pl_logger = _configure_logger(network, config, use_wandb) + pl_logger = _configure_logger(network, config, log_dir, use_wandb) # Load ligtning model. lit_model = _load_lit_model(lit_model_class, network, config) @@ -151,7 +153,7 @@ def run( **config.trainer.args, callbacks=callbacks, logger=pl_logger, - weigths_save_path="training/logs", + weigths_save_path=log_dir, ) if tune: |