diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-05 23:05:25 +0200 |
commit | 4d1f2cef39688871d2caafce42a09316381a27ae (patch) | |
tree | 0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/callbacks | |
parent | f0481decdad9afb52494e9e95996deef843ef233 (diff) |
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/callbacks')
-rw-r--r-- | text_recognizer/callbacks/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/callbacks/wandb_callbacks.py | 8 |
2 files changed, 5 insertions, 4 deletions
diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py index e69de29..82d8ce3 100644 --- a/text_recognizer/callbacks/__init__.py +++ b/text_recognizer/callbacks/__init__.py @@ -0,0 +1 @@ +"""Module for PyTorch Lightning callbacks.""" diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py index 900c3b1..4186b4a 100644 --- a/text_recognizer/callbacks/wandb_callbacks.py +++ b/text_recognizer/callbacks/wandb_callbacks.py @@ -29,7 +29,7 @@ class WatchModel(Callback): log: str = attr.ib(default="gradients") log_freq: int = attr.ib(default=100) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -44,7 +44,7 @@ class UploadCodeAsArtifact(Callback): project_dir: Path = attr.ib(converter=Path) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -65,7 +65,7 @@ class UploadCheckpointAsArtifact(Callback): ckpt_dir: Path = attr.ib(converter=Path) upload_best_only: bool = attr.ib() - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None: @@ -90,7 +90,7 @@ class LogTextPredictions(Callback): num_samples: int = attr.ib(default=8) ready: bool = attr.ib(default=True) - def __attrs_pre_init__(self): + def __attrs_pre_init__(self) -> None: super().__init__() def on_sanity_check_start( |