diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:58:15 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:58:15 +0200 |
commit | 41c4d214f754295c5cf0b5b978e81b069336b574 (patch) | |
tree | bdc8c4ae718df2ce141563acbd85d8cdf3ce48df /training | |
parent | 54cca2f8d64b5e1443ad0f9c2ceb8ed9260d0810 (diff) |
Update run script with torchinfo
Diffstat (limited to 'training')
-rw-r--r-- | training/run.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/training/run.py b/training/run.py index ddae2b9..b854376 100644 --- a/training/run.py +++ b/training/run.py @@ -13,7 +13,7 @@ from pytorch_lightning import ( ) from pytorch_lightning.loggers import LightningLoggerBase from torch import nn -from torchsummary import summary +from torchinfo import summary from text_recognizer.data.base_mapping import AbstractMapping import utils @@ -38,9 +38,6 @@ def run(config: DictConfig) -> Optional[float]: log.info(f"Instantiating network <{config.network._target_}>") network: nn.Module = hydra.utils.instantiate(config.network) - if config.summary: - summary(network, tuple(config.summary), device="cpu") - log.info(f"Instantiating criterion <{config.criterion._target_}>") loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion) @@ -69,6 +66,11 @@ def run(config: DictConfig) -> Optional[float]: utils.log_hyperparameters(config=config, model=model, trainer=trainer) utils.save_config(config) + if config.get("summary"): + summary( + network, list(map(lambda x: list(x), config.summary)), depth=1, device="cpu" + ) + if config.debug: log.info("Fast development run...") trainer.fit(model, datamodule=datamodule) |