summaryrefslogtreecommitdiff
path: root/training/run.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/run.py')
-rw-r--r--training/run.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/training/run.py b/training/run.py
index d88a8f6..30479c6 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Type
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
@@ -12,6 +12,7 @@ from pytorch_lightning import (
Trainer,
)
from pytorch_lightning.loggers import LightningLoggerBase
+from text_recognizer.data.mappings import AbstractMapping
from torch import nn
import utils
@@ -25,15 +26,19 @@ def run(config: DictConfig) -> Optional[float]:
if config.get("seed"):
seed_everything(config.seed)
+ log.info(f"Instantiating mapping <{config.mapping._target_}>")
+ mapping: AbstractMapping = hydra.utils.instantiate(config.mapping)
+
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
- datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
+ datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)
log.info(f"Instantiating network <{config.network._target_}>")
- network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config())
+ network: nn.Module = hydra.utils.instantiate(config.network)
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
**config.model,
+ mapping=mapping,
network=network,
criterion_config=config.criterion,
optimizer_config=config.optimizer,