From bd4bd443f339e95007bfdabf3e060db720f4d4b9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 3 Aug 2021 18:18:48 +0200 Subject: Training working, multiple bug fixes --- training/run.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) (limited to 'training/run.py') diff --git a/training/run.py b/training/run.py index 30479c6..13a6a82 100644 --- a/training/run.py +++ b/training/run.py @@ -12,35 +12,40 @@ from pytorch_lightning import ( Trainer, ) from pytorch_lightning.loggers import LightningLoggerBase -from text_recognizer.data.mappings import AbstractMapping from torch import nn +from text_recognizer.data.base_mapping import AbstractMapping import utils def run(config: DictConfig) -> Optional[float]: """Runs experiment.""" - utils.configure_logging(config.logging) + utils.configure_logging(config) log.info("Starting experiment...") if config.get("seed"): - seed_everything(config.seed) + seed_everything(config.seed, workers=True) log.info(f"Instantiating mapping <{config.mapping._target_}>") mapping: AbstractMapping = hydra.utils.instantiate(config.mapping) log.info(f"Instantiating datamodule <{config.datamodule._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping) + datamodule: LightningDataModule = hydra.utils.instantiate( + config.datamodule, mapping=mapping + ) log.info(f"Instantiating network <{config.network._target_}>") network: nn.Module = hydra.utils.instantiate(config.network) + log.info(f"Instantiating criterion <{config.criterion._target_}>") + loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion) + log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( - **config.model, + config.model, mapping=mapping, network=network, - criterion_config=config.criterion, + loss_fn=loss_fn, optimizer_config=config.optimizer, lr_scheduler_config=config.lr_scheduler, _recursive_=False, @@ -77,4 +82,4 @@ def run(config: DictConfig) -> Optional[float]: trainer.test(model, datamodule=datamodule) log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") - utils.finish(trainer) + utils.finish(logger) -- cgit v1.2.3-70-g09d2