diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
commit | 240f5e9f20032e82515fa66ce784619527d1041e (patch) | |
tree | b002d28bbfc9abe9b6af090f7db60bea0aeed6e8 /text_recognizer/criterions | |
parent | d12f70402371dda586d457af2a3df7fb5b3130ad (diff) |
Add VQGAN and loss function
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/n_layer_discriminator.py | 58 | ||||
-rw-r--r-- | text_recognizer/criterions/vqgan_loss.py | 85 |
2 files changed, 143 insertions, 0 deletions
diff --git a/text_recognizer/criterions/n_layer_discriminator.py b/text_recognizer/criterions/n_layer_discriminator.py new file mode 100644 index 0000000..e5f8449 --- /dev/null +++ b/text_recognizer/criterions/n_layer_discriminator.py @@ -0,0 +1,58 @@ +"""Pix2pix discriminator loss.""" +from torch import nn, Tensor + +from text_recognizer.networks.vqvae.norm import Normalize + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator loss in Pix2Pix.""" + + def __init__( + self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 + ) -> None: + super().__init__() + self.in_channels = in_channels + self.num_channels = num_channels + self.num_layers = num_layers + self.discriminator = self._build_discriminator() + + def _build_discriminator(self) -> nn.Sequential: + """Builds discriminator.""" + discriminator = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.num_channels, + kernel_size=4, + stride=2, + padding=1, + ), + nn.Mish(inplace=True), + ] + in_channels = self.num_channels + for n in range(1, self.num_layers): + discriminator += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels * n, + kernel_size=4, + stride=2, + padding=1, + ), + Normalize(num_channels=in_channels * n), + nn.Mish(inplace=True), + ] + in_channels *= n + + discriminator += [ + nn.Conv2d( + in_channels=self.num_channels * (self.num_layers - 1), + out_channels=1, + kernel_size=4, + padding=1, + ) + ] + return nn.Sequential(*discriminator) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through discriminator.""" + return self.discriminator(x) diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py new file mode 100644 index 0000000..8bb568f --- /dev/null +++ b/text_recognizer/criterions/vqgan_loss.py @@ -0,0 +1,85 @@ +"""VQGAN loss for PyTorch Lightning.""" +from typing import Dict +from click.types import Tuple + +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator + + +class VQGANLoss(nn.Module): + """VQGAN loss.""" + + def __init__( + self, + reconstruction_loss: nn.L1Loss, + discriminator: NLayerDiscriminator, + vq_loss_weight: float = 1.0, + discriminator_weight: float = 1.0, + ) -> None: + super().__init__() + self.reconstruction_loss = reconstruction_loss + self.discriminator = discriminator + self.vq_loss_weight = vq_loss_weight + self.discriminator_weight = discriminator_weight + + @staticmethod + def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: + """Calculates the adversarial loss.""" + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = (loss_real + loss_fake) / 2.0 + return d_loss + + def forward( + self, + data: Tensor, + reconstructions: Tensor, + vq_loss: Tensor, + optimizer_idx: int, + stage: str, + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Calculates the VQGAN loss.""" + rec_loss = self.reconstruction_loss( + data.contiguous(), reconstructions.contiguous() + ) + + # GAN part. + if optimizer_idx == 0: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + + loss = ( + rec_loss + + self.discriminator_weight * g_loss + + self.vq_loss_weight * vq_loss + ) + log = { + f"{stage}/loss": loss, + f"{stage}/vq_loss": vq_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/g_loss": g_loss, + } + return loss, log + + if optimizer_idx == 1: + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + logits_real = self.discriminator(data.contiguous().detach()) + + d_loss = self.adversarial_loss( + logits_real=logits_real, logits_fake=logits_fake + ) + loss = ( + rec_loss + + self.discriminator_weight * d_loss + + self.vq_loss_weight * vq_loss + ) + log = { + f"{stage}/loss": loss, + f"{stage}/vq_loss": vq_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/d_loss": d_loss, + } + return loss, log |