From 9ae5fa1a88899180f88ddb14d4cef457ceb847e5 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Mon, 5 Apr 2021 20:47:55 +0200
Subject: Add new training loop with PyTorch Lightning, remove stale files

---
 text_recognizer/models/base.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

(limited to 'text_recognizer/models')

diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 46e5136..2d6e435 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
 """Base PyTorch Lightning model."""
-from typing import Any, Dict, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
 
 import madgrad
 import pytorch_lightning as pl
@@ -40,7 +40,7 @@ class LitBaseModel(pl.LightningModule):
         args = {} or criterion_args["args"]
         return getattr(nn, criterion_args["type"])(**args)
 
-    def configure_optimizer(self) -> Dict[str, Any]:
+    def configure_optimizer(self) -> Tuple[List[type], List[Dict[str, Any]]]:
         """Configures optimizer and lr scheduler."""
         args = {} or self.optimizer_args["args"]
         if self.optimizer_args["type"] == "MADGRAD":
@@ -48,15 +48,15 @@ class LitBaseModel(pl.LightningModule):
         else:
             optimizer = getattr(torch.optim, self.optimizer_args["type"])(**args)
 
+        scheduler = {"monitor": self.monitor}
         args = {} or self.lr_scheduler_args["args"]
-        scheduler = getattr(torch.optim.lr_scheduler, self.lr_scheduler_args["type"])(
-            **args
-        )
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": scheduler,
-            "monitor": self.monitor,
-        }
+        if "interval" in args:
+            scheduler["interval"] = args.pop("interval")
+
+        scheduler["scheduler"] = getattr(
+            torch.optim.lr_scheduler, self.lr_scheduler_args["type"]
+        )(**args)
+        return [optimizer], [scheduler]
 
     def forward(self, data: Tensor) -> Tensor:
         """Feedforward pass."""
-- 
cgit v1.2.3-70-g09d2