diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-22 08:15:58 +0200 |
commit | 1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch) | |
tree | 5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/models/base.py | |
parent | ffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff) |
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index aeda039..88ffde6 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -40,6 +40,15 @@ class LitBaseModel(pl.LightningModule): args = {} or criterion.args return getattr(nn, criterion.type)(**args) + def optimizer_zero_grad( + self, + epoch: int, + batch_idx: int, + optimizer: Type[torch.optim.Optimizer], + optimizer_idx: int, + ) -> None: + optimizer.zero_grad(set_to_none=True) + def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: """Configures the optimizer.""" args = {} or self._optimizer.args |