summaryrefslogtreecommitdiff
path: root/text_recognizer/callbacks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
commit4d1f2cef39688871d2caafce42a09316381a27ae (patch)
tree0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/callbacks
parentf0481decdad9afb52494e9e95996deef843ef233 (diff)
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/callbacks')
-rw-r--r--text_recognizer/callbacks/__init__.py1
-rw-r--r--text_recognizer/callbacks/wandb_callbacks.py8
2 files changed, 5 insertions, 4 deletions
diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py
index e69de29..82d8ce3 100644
--- a/text_recognizer/callbacks/__init__.py
+++ b/text_recognizer/callbacks/__init__.py
@@ -0,0 +1 @@
+"""Module for PyTorch Lightning callbacks."""
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
index 900c3b1..4186b4a 100644
--- a/text_recognizer/callbacks/wandb_callbacks.py
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -29,7 +29,7 @@ class WatchModel(Callback):
log: str = attr.ib(default="gradients")
log_freq: int = attr.ib(default=100)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -44,7 +44,7 @@ class UploadCodeAsArtifact(Callback):
project_dir: Path = attr.ib(converter=Path)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -65,7 +65,7 @@ class UploadCheckpointAsArtifact(Callback):
ckpt_dir: Path = attr.ib(converter=Path)
upload_best_only: bool = attr.ib()
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -90,7 +90,7 @@ class LogTextPredictions(Callback):
num_samples: int = attr.ib(default=8)
ready: bool = attr.ib(default=True)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_sanity_check_start(