diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-08 23:38:03 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-08 23:38:03 +0200 |
commit | e388cd95c77d37a51324cff9d84a809421bf97d3 (patch) | |
tree | d585545f85d03ea8a6907daba254821fddeb1589 /text_recognizer/models | |
parent | f4629a0d4149d5870c9fd8ce83ff5d391bd7ddd3 (diff) |
Bug fixes word pieces
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 3c1919e..0928e6c 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -49,7 +49,7 @@ class LitBaseModel(pl.LightningModule): optimizer_class = getattr(torch.optim, self._optimizer.type) return optimizer_class(params=self.parameters(), **args) - def _configure_lr_scheduler(self) -> Dict[str, Any]: + def _configure_lr_scheduler(self, optimizer: Type[torch.optim.Optimizer]) -> Dict[str, Any]: """Configures the lr scheduler.""" scheduler = {"monitor": self.monitor} args = {} or self._lr_scheduler.args @@ -59,13 +59,13 @@ class LitBaseModel(pl.LightningModule): scheduler["scheduler"] = getattr( torch.optim.lr_scheduler, self._lr_scheduler.type - )(**args) + )(optimizer, **args) return scheduler def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" optimizer = self._configure_optimizer() - scheduler = self._configure_lr_scheduler() + scheduler = self._configure_lr_scheduler(optimizer) return [optimizer], [scheduler] |