From 41c4d214f754295c5cf0b5b978e81b069336b574 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 30 Sep 2021 23:58:15 +0200
Subject: Update run script with torchinfo

---
 training/run.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/training/run.py b/training/run.py
index ddae2b9..b854376 100644
--- a/training/run.py
+++ b/training/run.py
@@ -13,7 +13,7 @@ from pytorch_lightning import (
 )
 from pytorch_lightning.loggers import LightningLoggerBase
 from torch import nn
-from torchsummary import summary
+from torchinfo import summary
 from text_recognizer.data.base_mapping import AbstractMapping
 
 import utils
@@ -38,9 +38,6 @@ def run(config: DictConfig) -> Optional[float]:
     log.info(f"Instantiating network <{config.network._target_}>")
     network: nn.Module = hydra.utils.instantiate(config.network)
 
-    if config.summary:
-        summary(network, tuple(config.summary), device="cpu")
-
     log.info(f"Instantiating criterion <{config.criterion._target_}>")
     loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
 
@@ -69,6 +66,11 @@ def run(config: DictConfig) -> Optional[float]:
     utils.log_hyperparameters(config=config, model=model, trainer=trainer)
     utils.save_config(config)
 
+    if config.get("summary"):
+        summary(
+            network, list(map(lambda x: list(x), config.summary)), depth=1, device="cpu"
+        )
+
     if config.debug:
         log.info("Fast development run...")
         trainer.fit(model, datamodule=datamodule)
-- 
cgit v1.2.3-70-g09d2