diff options
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/vqgan_loss.py | 23 |
1 files changed, 8 insertions, 15 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index 8bb568f..87f0f1c 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -1,5 +1,5 @@ """VQGAN loss for PyTorch Lightning.""" -from typing import Dict +from typing import Dict, Optional from click.types import Tuple import torch @@ -40,9 +40,9 @@ class VQGANLoss(nn.Module): vq_loss: Tensor, optimizer_idx: int, stage: str, - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> Optional[Tuple]: """Calculates the VQGAN loss.""" - rec_loss = self.reconstruction_loss( + rec_loss: Tensor = self.reconstruction_loss( data.contiguous(), reconstructions.contiguous() ) @@ -51,13 +51,13 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous()) g_loss = -torch.mean(logits_fake) - loss = ( + loss: Tensor = ( rec_loss + self.discriminator_weight * g_loss + self.vq_loss_weight * vq_loss ) log = { - f"{stage}/loss": loss, + f"{stage}/total_loss": loss, f"{stage}/vq_loss": vq_loss, f"{stage}/rec_loss": rec_loss, f"{stage}/g_loss": g_loss, @@ -68,18 +68,11 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_real = self.discriminator(data.contiguous().detach()) - d_loss = self.adversarial_loss( + d_loss = self.discriminator_weight * self.adversarial_loss( logits_real=logits_real, logits_fake=logits_fake ) - loss = ( - rec_loss - + self.discriminator_weight * d_loss - + self.vq_loss_weight * vq_loss - ) + log = { - f"{stage}/loss": loss, - f"{stage}/vq_loss": vq_loss, - f"{stage}/rec_loss": rec_loss, f"{stage}/d_loss": d_loss, } - return loss, log + return d_loss, log |