summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 20:47:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-05 20:47:55 +0200
commit9ae5fa1a88899180f88ddb14d4cef457ceb847e5 (patch)
tree4fe2bcd82553c8062eb0908ae6442c123addf55d /text_recognizer/models
parent9e54591b7e342edc93b0bb04809a0f54045c6a15 (diff)
Add new training loop with PyTorch Lightning, remove stale files
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/base.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 46e5136..2d6e435 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
import madgrad
import pytorch_lightning as pl
@@ -40,7 +40,7 @@ class LitBaseModel(pl.LightningModule):
args = {} or criterion_args["args"]
return getattr(nn, criterion_args["type"])(**args)
- def configure_optimizer(self) -> Dict[str, Any]:
+ def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:
"""Configures optimizer and lr scheduler."""
args = {} or self.optimizer_args["args"]
if self.optimizer_args["type"] == "MADGRAD":
@@ -48,15 +48,15 @@ class LitBaseModel(pl.LightningModule):
else:
optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
+ scheduler = {"monitor": self.monitor}
args = {} or self.lr_scheduler_args["args"]
- scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
- **args
- )
- return {
- "optimizer": optimizer,
- "lr_scheduler": scheduler,
- "monitor": self.monitor,
- }
+ if "interval" in args:
+ scheduler["interval"] = args.pop("interval")
+
+ scheduler["scheduler"] = getattr(
+ torch.optim.lr_scheduler, self.lr_scheduler_args["type"]
+ )(**args)
+ return [optimizer], [scheduler]
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""