diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:02:48 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:02:48 +0200 |
commit | 3c41a7061e5bf6c648e7c7216d64c29dc342a0ca (patch) | |
tree | a5788389a063a6c0d955c91c576e7372aa788bd9 /text_recognizer/criterions/vqgan_loss.py | |
parent | 08d73ff01e5e0590e11d5d44a3c85a16bca76ce5 (diff) |
Rename vqloss to commitment loss
Diffstat (limited to 'text_recognizer/criterions/vqgan_loss.py')
-rw-r--r-- | text_recognizer/criterions/vqgan_loss.py | 22 |
1 files changed, 11 insertions, 11 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index 7af1a55..9d1cddd 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -24,7 +24,7 @@ class VQGANLoss(nn.Module): self, reconstruction_loss: nn.L1Loss, discriminator: NLayerDiscriminator, - vq_loss_weight: float = 1.0, + commitment_weight: float = 1.0, discriminator_weight: float = 1.0, discriminator_factor: float = 1.0, discriminator_iter_start: int = 1000, @@ -32,7 +32,7 @@ class VQGANLoss(nn.Module): super().__init__() self.reconstruction_loss = reconstruction_loss self.discriminator = discriminator - self.vq_loss_weight = vq_loss_weight + self.commitment_weight = commitment_weight self.discriminator_weight = discriminator_weight self.discriminator_factor = discriminator_factor self.discriminator_iter_start = discriminator_iter_start @@ -61,20 +61,18 @@ class VQGANLoss(nn.Module): self, data: Tensor, reconstructions: Tensor, - vq_loss: Tensor, + commitment_loss: Tensor, decoder_last_layer: Tensor, optimizer_idx: int, global_step: int, stage: str, ) -> Optional[Tuple]: """Calculates the VQGAN loss.""" - rec_loss: Tensor = self.reconstruction_loss( - data.contiguous(), reconstructions.contiguous() - ) + rec_loss: Tensor = self.reconstruction_loss(reconstructions, data) # GAN part. if optimizer_idx == 0: - logits_fake = self.discriminator(reconstructions.contiguous()) + logits_fake = self.discriminator(reconstructions) g_loss = -torch.mean(logits_fake) if self.training: @@ -93,19 +91,21 @@ class VQGANLoss(nn.Module): ) loss: Tensor = ( - rec_loss + d_factor * d_weight * g_loss + self.vq_loss_weight * vq_loss + rec_loss + + d_factor * d_weight * g_loss + + self.commitment_weight * commitment_loss ) log = { f"{stage}/total_loss": loss, - f"{stage}/vq_loss": vq_loss, + f"{stage}/commitment_loss": commitment_loss, f"{stage}/rec_loss": rec_loss, f"{stage}/g_loss": g_loss, } return loss, log if optimizer_idx == 1: - logits_fake = self.discriminator(reconstructions.contiguous().detach()) - logits_real = self.discriminator(data.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.detach()) + logits_real = self.discriminator(data.detach()) d_factor = _adopt_weight( self.discriminator_factor, |