From da7d2171c818afefb3bad3cd66ce85fddd519c1c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 15 Aug 2021 21:15:31 +0200 Subject: Updates to VQGAN loss --- text_recognizer/criterions/vqgan_loss.py | 53 +++++++++++++++++++++++++++++--- text_recognizer/models/vqgan.py | 20 ++++++++---- 2 files changed, 63 insertions(+), 10 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index 87f0f1c..cfd507d 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -9,6 +9,14 @@ import torch.nn.functional as F from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator +def adopt_weight( + weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 +) -> float: + if global_step < threshold: + weight = value + return weight + + class VQGANLoss(nn.Module): """VQGAN loss.""" @@ -18,12 +26,16 @@ class VQGANLoss(nn.Module): discriminator: NLayerDiscriminator, vq_loss_weight: float = 1.0, discriminator_weight: float = 1.0, + discriminator_factor: float = 1.0, + discriminator_iter_start: int = 1000, ) -> None: super().__init__() self.reconstruction_loss = reconstruction_loss self.discriminator = discriminator self.vq_loss_weight = vq_loss_weight self.discriminator_weight = discriminator_weight + self.discriminator_factor = discriminator_factor + self.discriminator_iter_start = discriminator_iter_start @staticmethod def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: @@ -33,12 +45,26 @@ class VQGANLoss(nn.Module): d_loss = (loss_real + loss_fake) / 2.0 return d_loss + def _adaptive_weight( + self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor + ) -> Tensor: + rec_grads = torch.autograd.grad( + rec_loss, decoder_last_layer, retain_graph=True + )[0] + g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0] + d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4) + d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach() + d_weight *= self.discriminator_weight + return d_weight + def forward( self, data: Tensor, reconstructions: Tensor, vq_loss: Tensor, + decoder_last_layer: Tensor, optimizer_idx: int, + global_step: int, stage: str, ) -> Optional[Tuple]: """Calculates the VQGAN loss.""" @@ -51,10 +77,23 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous()) g_loss = -torch.mean(logits_fake) + if self.training: + d_weight = self._adaptive_weight( + rec_loss=rec_loss, + g_loss=g_loss, + decoder_last_layer=decoder_last_layer, + ) + else: + d_weight = torch.tensor(0.0) + + d_factor = adopt_weight( + self.discriminator_factor, + global_step=global_step, + threshold=self.discriminator_iter_start, + ) + loss: Tensor = ( - rec_loss - + self.discriminator_weight * g_loss - + self.vq_loss_weight * vq_loss + rec_loss + d_factor * d_weight * g_loss + self.vq_loss_weight * vq_loss ) log = { f"{stage}/total_loss": loss, @@ -68,7 +107,13 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_real = self.discriminator(data.contiguous().detach()) - d_loss = self.discriminator_weight * self.adversarial_loss( + d_factor = adopt_weight( + self.discriminator_factor, + global_step=global_step, + threshold=self.discriminator_iter_start, + ) + + d_loss = d_factor * self.adversarial_loss( logits_real=logits_real, logits_fake=logits_fake ) diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 7c707b1..2f67b35 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -32,13 +32,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/loss", - loss, - prog_bar=True, + "train/loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -48,13 +48,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/discriminator_loss", - loss, - prog_bar=True, + "train/discriminator_loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -68,7 +68,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="val", ) self.log( @@ -80,7 +82,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="val", ) self.log_dict(log) @@ -94,7 +98,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="test", ) self.log_dict(log) @@ -103,7 +109,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="test", ) self.log_dict(log) -- cgit v1.2.3-70-g09d2