summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-29 21:40:19 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-29 21:40:19 +0200
commit2f1bb639fd5bb6b510af85fb597e9322abc17bc0 (patch)
tree3269155b33f33bf2964dc1bdff34d7929b3227f2 /text_recognizer
parentda7d2171c818afefb3bad3cd66ce85fddd519c1c (diff)
Remove uploading of code to Wandb, upload config instead
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/models/base.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 94dbde5..56d4ca5 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type
import attr
import hydra
@@ -25,7 +25,7 @@ class BaseLitModel(LightningModule):
mapping: Type[AbstractMapping] = attr.ib()
loss_fn: Type[nn.Module] = attr.ib()
optimizer_configs: DictConfig = attr.ib()
- lr_scheduler_configs: DictConfig = attr.ib()
+ lr_scheduler_configs: Optional[DictConfig] = attr.ib()
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)
@@ -55,9 +55,7 @@ class BaseLitModel(LightningModule):
del optimizer_config.parameters
log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
optimizers.append(
- hydra.utils.instantiate(
- optimizer_config, params=module.parameters()
- )
+ hydra.utils.instantiate(optimizer_config, params=module.parameters())
)
return optimizers
@@ -65,6 +63,8 @@ class BaseLitModel(LightningModule):
self, optimizers: List[Type[torch.optim.Optimizer]]
) -> List[Dict[str, Any]]:
"""Configures the lr scheduler."""
+ if None in self.lr_scheduler_configs:
+ return []
schedulers = []
for optimizer, lr_scheduler_config in zip(
optimizers, self.lr_scheduler_configs.values()