summaryrefslogtreecommitdiff
path: root/text_recognizer/model/base.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-02 01:52:44 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2023-09-02 01:52:44 +0200
commit579f3edc3e20ddbe8207ee0c4189a270b2dfedc1 (patch)
treedb84fd73e62ee4705088953ab55f871c86603dca /text_recognizer/model/base.py
parent09f9eab02ef40b1ca26e4693ad77f1f2df79a945 (diff)
Refactor lit models
Diffstat (limited to 'text_recognizer/model/base.py')
-rw-r--r--text_recognizer/model/base.py10
1 files changed, 10 insertions, 0 deletions
diff --git a/text_recognizer/model/base.py b/text_recognizer/model/base.py
index 1cff796..adcb8da 100644
--- a/text_recognizer/model/base.py
+++ b/text_recognizer/model/base.py
@@ -94,3 +94,13 @@ class LitBase(L.LightningModule):
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
pass
+
+ def is_logged_batch(self) -> bool:
+ if self.trainer is None:
+ return False
+ else:
+ return self.trainer._logger_connector.should_update_logs
+
+ def add_on_first_batch(self, metrics: dict, output: dict, batch_idx: int) -> None:
+ if batch_idx == 0:
+ output.update(metrics)