path: root/training/run.py
author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-03 18:18:48 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-03 18:18:48 +0200
commit     bd4bd443f339e95007bfdabf3e060db720f4d4b9 (patch)
tree       e55cb3744904f7c2a0348b100c7e92a65e538a16 /training/run.py
parent     75801019981492eedf9280cb352eea3d8e99b65f (diff)
Training working, multiple bug fixes
Diffstat (limited to 'training/run.py')
-rw-r--r--  training/run.py  19
1 file changed, 12 insertions, 7 deletions
diff --git a/training/run.py b/training/run.py
index 30479c6..13a6a82 100644
--- a/training/run.py
+++ b/training/run.py
@@ -12,35 +12,40 @@ from pytorch_lightning import (
Trainer,
)
from pytorch_lightning.loggers import LightningLoggerBase
-from text_recognizer.data.mappings import AbstractMapping
from torch import nn
+from text_recognizer.data.base_mapping import AbstractMapping
import utils
def run(config: DictConfig) -> Optional[float]:
"""Runs experiment."""
- utils.configure_logging(config.logging)
+ utils.configure_logging(config)
log.info("Starting experiment...")
if config.get("seed"):
- seed_everything(config.seed)
+ seed_everything(config.seed, workers=True)
log.info(f"Instantiating mapping <{config.mapping._target_}>")
mapping: AbstractMapping = hydra.utils.instantiate(config.mapping)
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
- datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)
+ datamodule: LightningDataModule = hydra.utils.instantiate(
+ config.datamodule, mapping=mapping
+ )
log.info(f"Instantiating network <{config.network._target_}>")
network: nn.Module = hydra.utils.instantiate(config.network)
+ log.info(f"Instantiating criterion <{config.criterion._target_}>")
+ loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
+
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
- **config.model,
+ config.model,
mapping=mapping,
network=network,
- criterion_config=config.criterion,
+ loss_fn=loss_fn,
optimizer_config=config.optimizer,
lr_scheduler_config=config.lr_scheduler,
_recursive_=False,
@@ -77,4 +82,4 @@ def run(config: DictConfig) -> Optional[float]:
trainer.test(model, datamodule=datamodule)
log.info(f"Best checkpoint path:\n{trainer.checkpoint_callback.best_model_path}")
- utils.finish(trainer)
+ utils.finish(logger)
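
The hunk above relies on Hydra's _target_-based instantiation: each config node names a class, keyword arguments passed to hydra.utils.instantiate supplement or override the node's own keys (which is why the call now takes config.model rather than **config.model), and _recursive_=False hands nested nodes such as optimizer_config to the constructor as plain DictConfig objects instead of instantiating them eagerly. The related seed_everything(config.seed, workers=True) change additionally seeds dataloader worker processes, not just the main process. Below is a minimal, self-contained sketch of that instantiation pattern; the config keys and the torch.nn.Linear / torch.nn.Identity targets are illustrative placeholders, not this repository's actual configs.

# Sketch of the _target_-style instantiation used in run() above.
# All targets and keys here are illustrative assumptions, not the
# repository's real configs.
import hydra
from omegaconf import OmegaConf
from torch import nn

cfg = OmegaConf.create(
    {
        # _target_ names the class to build; the remaining keys become kwargs.
        "network": {"_target_": "torch.nn.Linear", "in_features": 8, "out_features": 2},
        "criterion": {"_target_": "torch.nn.CrossEntropyLoss"},
        # Placeholder for the LightningModule target; nn.Identity silently
        # ignores any keyword arguments it receives.
        "model": {"_target_": "torch.nn.Identity"},
        "optimizer": {"_target_": "torch.optim.Adam", "lr": 1e-3},
    }
)

# Plain instantiation: the node is resolved into an object immediately.
network: nn.Module = hydra.utils.instantiate(cfg.network)
loss_fn: nn.Module = hydra.utils.instantiate(cfg.criterion)

# Keyword arguments are forwarded to the target's constructor, and
# _recursive_=False keeps cfg.optimizer as a DictConfig so the model
# could build its optimizer later (e.g. in configure_optimizers()).
model = hydra.utils.instantiate(
    cfg.model,
    network=network,
    loss_fn=loss_fn,
    optimizer_config=cfg.optimizer,
    _recursive_=False,
)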