summaryrefslogtreecommitdiff
path: root/training/run.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 02:42:45 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 02:42:45 +0200
commit3ab82ad36bce6fa698a13a029a0694b75a5947b7 (patch)
tree136f71a62d60e3ccf01e1f95d64bb4d9f9c9befe /training/run.py
parent1bccf71cf4eec335001b50a8fbc0c991d0e6d13a (diff)
Fix VQVAE into en/decoder, bug in wandb artifact code uploading
Diffstat (limited to 'training/run.py')
-rw-r--r--training/run.py4
1 files changed, 4 insertions, 0 deletions
diff --git a/training/run.py b/training/run.py
index 13a6a82..a2529b0 100644
--- a/training/run.py
+++ b/training/run.py
@@ -13,6 +13,7 @@ from pytorch_lightning import (
)
from pytorch_lightning.loggers import LightningLoggerBase
from torch import nn
+from torchsummary import summary
from text_recognizer.data.base_mapping import AbstractMapping
import utils
@@ -37,6 +38,9 @@ def run(config: DictConfig) -> Optional[float]:
log.info(f"Instantiating network <{config.network._target_}>")
network: nn.Module = hydra.utils.instantiate(config.network)
+ if config.summary:
+ summary(network, tuple(config.summary), device="cpu")
+
log.info(f"Instantiating criterion <{config.criterion._target_}>")
loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)