diff options
Diffstat (limited to 'training/run.py')
-rw-r--r-- | training/run.py | 19 |
1 files changed, 12 insertions, 7 deletions
diff --git a/training/run.py b/training/run.py index 30479c6..13a6a82 100644 --- a/training/run.py +++ b/training/run.py @@ -12,35 +12,40 @@ from pytorch_lightning import ( Trainer, ) from pytorch_lightning.loggers import LightningLoggerBase -from text_recognizer.data.mappings import AbstractMapping from torch import nn +from text_recognizer.data.base_mapping import AbstractMapping import utils def run(config: DictConfig) -> Optional[float]: """Runs experiment.""" - utils.configure_logging(config.logging) + utils.configure_logging(config) log.info("Starting experiment...") if config.get("seed"): - seed_everything(config.seed) + seed_everything(config.seed, workers=True) log.info(f"Instantiating mapping <{config.mapping._target_}>") mapping: AbstractMapping = hydra.utils.instantiate(config.mapping) log.info(f"Instantiating datamodule <{config.datamodule._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping) + datamodule: LightningDataModule = hydra.utils.instantiate( + config.datamodule, mapping=mapping + ) log.info(f"Instantiating network <{config.network._target_}>") network: nn.Module = hydra.utils.instantiate(config.network) + log.info(f"Instantiating criterion <{config.criterion._target_}>") + loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion) + log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( - **config.model, + config.model, mapping=mapping, network=network, - criterion_config=config.criterion, + loss_fn=loss_fn, optimizer_config=config.optimizer, lr_scheduler_config=config.lr_scheduler, _recursive_=False, @@ -77,4 +82,4 @@ def run(config: DictConfig) -> Optional[float]: trainer.test(model, datamodule=datamodule) log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}") - utils.finish(trainer) + utils.finish(logger) |