diff options
Diffstat (limited to 'training/run.py')
| -rw-r--r-- | training/run.py | 11 | 
1 files changed, 8 insertions, 3 deletions
| diff --git a/training/run.py b/training/run.py index d88a8f6..30479c6 100644 --- a/training/run.py +++ b/training/run.py @@ -2,7 +2,7 @@  from typing import List, Optional, Type  import hydra -import loguru.logger as log +from loguru import logger as log  from omegaconf import DictConfig  from pytorch_lightning import (      Callback, @@ -12,6 +12,7 @@ from pytorch_lightning import (      Trainer,  )  from pytorch_lightning.loggers import LightningLoggerBase +from text_recognizer.data.mappings import AbstractMapping  from torch import nn  import utils @@ -25,15 +26,19 @@ def run(config: DictConfig) -> Optional[float]:      if config.get("seed"):          seed_everything(config.seed) +    log.info(f"Instantiating mapping <{config.mapping._target_}>") +    mapping: AbstractMapping = hydra.utils.instantiate(config.mapping) +      log.info(f"Instantiating datamodule <{config.datamodule._target_}>") -    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) +    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)      log.info(f"Instantiating network <{config.network._target_}>") -    network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config()) +    network: nn.Module = hydra.utils.instantiate(config.network)      log.info(f"Instantiating model <{config.model._target_}>")      model: LightningModule = hydra.utils.instantiate(          **config.model, +        mapping=mapping,          network=network,          criterion_config=config.criterion,          optimizer_config=config.optimizer, |