From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- training/run.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) (limited to 'training/run.py') diff --git a/training/run.py b/training/run.py index d88a8f6..30479c6 100644 --- a/training/run.py +++ b/training/run.py @@ -2,7 +2,7 @@ from typing import List, Optional, Type import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import ( Callback, @@ -12,6 +12,7 @@ from pytorch_lightning import ( Trainer, ) from pytorch_lightning.loggers import LightningLoggerBase +from text_recognizer.data.mappings import AbstractMapping from torch import nn import utils @@ -25,15 +26,19 @@ def run(config: DictConfig) -> Optional[float]: if config.get("seed"): seed_everything(config.seed) + log.info(f"Instantiating mapping <{config.mapping._target_}>") + mapping: AbstractMapping = hydra.utils.instantiate(config.mapping) + log.info(f"Instantiating datamodule <{config.datamodule._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) + datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping) log.info(f"Instantiating network <{config.network._target_}>") - network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config()) + network: nn.Module = hydra.utils.instantiate(config.network) log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( **config.model, + mapping=mapping, network=network, criterion_config=config.criterion, optimizer_config=config.optimizer, -- cgit v1.2.3-70-g09d2