diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r--   text_recognizer/criterions/vqgan_loss.py | 23
-rw-r--r--   text_recognizer/models/base.py           |  8
-rw-r--r--   text_recognizer/models/vqgan.py          | 24
3 files changed, 15 insertions, 40 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index 8bb568f..87f0f1c 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -1,5 +1,5 @@  """VQGAN loss for PyTorch Lightning.""" -from typing import Dict +from typing import Dict, Optional  from click.types import Tuple  import torch @@ -40,9 +40,9 @@ class VQGANLoss(nn.Module):          vq_loss: Tensor,          optimizer_idx: int,          stage: str, -    ) -> Tuple[Tensor, Dict[str, Tensor]]: +    ) -> Optional[Tuple]:          """Calculates the VQGAN loss.""" -        rec_loss = self.reconstruction_loss( +        rec_loss: Tensor = self.reconstruction_loss(              data.contiguous(), reconstructions.contiguous()          ) @@ -51,13 +51,13 @@ class VQGANLoss(nn.Module):              logits_fake = self.discriminator(reconstructions.contiguous())              g_loss = -torch.mean(logits_fake) -            loss = ( +            loss: Tensor = (                  rec_loss                  + self.discriminator_weight * g_loss                  + self.vq_loss_weight * vq_loss              )              log = { -                f"{stage}/loss": loss, +                f"{stage}/total_loss": loss,                  f"{stage}/vq_loss": vq_loss,                  f"{stage}/rec_loss": rec_loss,                  f"{stage}/g_loss": g_loss, @@ -68,18 +68,11 @@ class VQGANLoss(nn.Module):              logits_fake = self.discriminator(reconstructions.contiguous().detach())              logits_real = self.discriminator(data.contiguous().detach()) -            d_loss = self.adversarial_loss( +            d_loss = self.discriminator_weight * self.adversarial_loss(                  logits_real=logits_real, logits_fake=logits_fake              ) -            loss = ( -                rec_loss -                + self.discriminator_weight * d_loss -                + self.vq_loss_weight * vq_loss -            ) +              log = { -  
              f"{stage}/loss": loss, -                f"{stage}/vq_loss": vq_loss, -                f"{stage}/rec_loss": rec_loss,                  f"{stage}/d_loss": d_loss,              } -            return loss, log +            return d_loss, log diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8b68ed9..94dbde5 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -49,12 +49,14 @@ class BaseLitModel(LightningModule):          """Configures the optimizer."""          optimizers = []          for optimizer_config in self.optimizer_configs.values(): -            network = getattr(self, optimizer_config.parameters) +            module = self +            for m in str(optimizer_config.parameters).split("."): +                module = getattr(module, m)              del optimizer_config.parameters              log.info(f"Instantiating optimizer <{optimizer_config._target_}>")              optimizers.append(                  hydra.utils.instantiate( -                    self.optimizer_config, params=network.parameters() +                    optimizer_config, params=module.parameters()                  )              )          return optimizers @@ -92,7 +94,7 @@ class BaseLitModel(LightningModule):      ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:          """Configures optimizer and lr scheduler."""          optimizers = self._configure_optimizer() -        schedulers = self._configure_lr_scheduler(optimizers) +        schedulers = self._configure_lr_schedulers(optimizers)          return optimizers, schedulers      def forward(self, data: Tensor) -> Tensor: diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 8ff65cc..80653b6 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -9,7 +9,7 @@ from text_recognizer.criterions.vqgan_loss import VQGANLoss  @attr.s(auto_attribs=True, eq=False) -class 
VQVAELitModel(BaseLitModel): +class VQGANLitModel(BaseLitModel):      """A PyTorch Lightning model for transformer networks."""      loss_fn: VQGANLoss = attr.ib() @@ -26,7 +26,6 @@ class VQVAELitModel(BaseLitModel):          data, _ = batch          reconstructions, vq_loss = self(data) -        loss = self.loss_fn(reconstructions, data)          if optimizer_idx == 0:              loss, log = self.loss_fn( @@ -81,14 +80,6 @@ class VQVAELitModel(BaseLitModel):          self.log(              "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True          ) -        self.log( -            "val/rec_loss", -            log["val/rec_loss"], -            prog_bar=True, -            logger=True, -            on_step=True, -            on_epoch=True, -        )          self.log_dict(log)          _, log = self.loss_fn( @@ -105,24 +96,13 @@ class VQVAELitModel(BaseLitModel):          data, _ = batch          reconstructions, vq_loss = self(data) -        loss, log = self.loss_fn( +        _, log = self.loss_fn(              data=data,              reconstructions=reconstructions,              vq_loss=vq_loss,              optimizer_idx=0,              stage="test",          ) -        self.log( -            "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True -        ) -        self.log( -            "test/rec_loss", -            log["test/rec_loss"], -            prog_bar=True, -            logger=True, -            on_step=True, -            on_epoch=True, -        )          self.log_dict(log)          _, log = self.loss_fn(  |