summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:58:15 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:58:15 +0200
commit41c4d214f754295c5cf0b5b978e81b069336b574 (patch)
treebdc8c4ae718df2ce141563acbd85d8cdf3ce48df /training
parent54cca2f8d64b5e1443ad0f9c2ceb8ed9260d0810 (diff)
Update run script with torchinfo
Diffstat (limited to 'training')
-rw-r--r--training/run.py10
1 files changed, 6 insertions, 4 deletions
diff --git a/training/run.py b/training/run.py
index ddae2b9..b854376 100644
--- a/training/run.py
+++ b/training/run.py
@@ -13,7 +13,7 @@ from pytorch_lightning import (
)
from pytorch_lightning.loggers import LightningLoggerBase
from torch import nn
-from torchsummary import summary
+from torchinfo import summary
from text_recognizer.data.base_mapping import AbstractMapping
import utils
@@ -38,9 +38,6 @@ def run(config: DictConfig) -> Optional[float]:
log.info(f"Instantiating network <{config.network._target_}>")
network: nn.Module = hydra.utils.instantiate(config.network)
- if config.summary:
- summary(network, tuple(config.summary), device="cpu")
-
log.info(f"Instantiating criterion <{config.criterion._target_}>")
loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
@@ -69,6 +66,11 @@ def run(config: DictConfig) -> Optional[float]:
utils.log_hyperparameters(config=config, model=model, trainer=trainer)
utils.save_config(config)
+ if config.get("summary"):
+ summary(
+ network, list(map(lambda x: list(x), config.summary)), depth=1, device="cpu"
+ )
+
if config.debug:
log.info("Fast development run...")
trainer.fit(model, datamodule=datamodule)