diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 21:43:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 21:43:39 +0200 |
commit | 82f4acabe24e5171c40afa2939a4777ba87bcc30 (patch) | |
tree | 4d327fa26e4662a0447a66375442a9adeb13ea3d /text_recognizer/models/base.py | |
parent | 240f5e9f20032e82515fa66ce784619527d1041e (diff) |
Add training of VQGAN
Diffstat (limited to 'text_recognizer/models/base.py')
-rw-r--r-- | text_recognizer/models/base.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8b68ed9..94dbde5 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -49,12 +49,14 @@ class BaseLitModel(LightningModule): """Configures the optimizer.""" optimizers = [] for optimizer_config in self.optimizer_configs.values(): - network = getattr(self, optimizer_config.parameters) + module = self + for m in str(optimizer_config.parameters).split("."): + module = getattr(module, m) del optimizer_config.parameters log.info(f"Instantiating optimizer <{optimizer_config._target_}>") optimizers.append( hydra.utils.instantiate( - self.optimizer_config, params=network.parameters() + optimizer_config, params=module.parameters() ) ) return optimizers @@ -92,7 +94,7 @@ class BaseLitModel(LightningModule): ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" optimizers = self._configure_optimizer() - schedulers = self._configure_lr_scheduler(optimizers) + schedulers = self._configure_lr_schedulers(optimizers) return optimizers, schedulers def forward(self, data: Tensor) -> Tensor: |