diff options
Diffstat (limited to 'text_recognizer/criterion')
-rw-r--r-- | text_recognizer/criterion/n_layer_discriminator.py | 59 | ||||
-rw-r--r-- | text_recognizer/criterion/vqgan_loss.py | 123 |
2 files changed, 0 insertions, 182 deletions
diff --git a/text_recognizer/criterion/n_layer_discriminator.py b/text_recognizer/criterion/n_layer_discriminator.py deleted file mode 100644 index a9f47f9..0000000 --- a/text_recognizer/criterion/n_layer_discriminator.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Pix2pix discriminator loss.""" -from torch import nn, Tensor - -from text_recognizer.networks.vqvae.norm import Normalize - - -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator loss in Pix2Pix.""" - - def __init__( - self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 - ) -> None: - super().__init__() - self.in_channels = in_channels - self.num_channels = num_channels - self.num_layers = num_layers - self.discriminator = self._build_discriminator() - - def _build_discriminator(self) -> nn.Sequential: - """Builds discriminator.""" - discriminator = [ - nn.Sigmoid(), - nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.num_channels, - kernel_size=4, - stride=2, - padding=1, - ), - nn.Mish(inplace=True), - ] - in_channels = self.num_channels - for n in range(1, self.num_layers): - discriminator += [ - nn.Conv2d( - in_channels=in_channels, - out_channels=in_channels * n, - kernel_size=4, - stride=2, - padding=1, - ), - # Normalize(num_channels=in_channels * n), - nn.Mish(inplace=True), - ] - in_channels *= n - - discriminator += [ - nn.Conv2d( - in_channels=self.num_channels * (self.num_layers - 1), - out_channels=1, - kernel_size=4, - padding=1, - ) - ] - return nn.Sequential(*discriminator) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass through discriminator.""" - return self.discriminator(x) diff --git a/text_recognizer/criterion/vqgan_loss.py b/text_recognizer/criterion/vqgan_loss.py deleted file mode 100644 index 8e8b65b..0000000 --- a/text_recognizer/criterion/vqgan_loss.py +++ /dev/null @@ -1,123 +0,0 @@ -"""VQGAN loss for PyTorch Lightning.""" -from typing import Optional, Tuple - -import torch -from torch import nn, Tensor -import torch.nn.functional as F - -from text_recognizer.criterion.n_layer_discriminator import NLayerDiscriminator - - -def _adopt_weight( - weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 -) -> float: - """Sets weight to value after the threshold is passed.""" - if global_step < threshold: - weight = value - return weight - - -class VQGANLoss(nn.Module): - """VQGAN loss.""" - - def __init__( - self, - reconstruction_loss: nn.L1Loss, - discriminator: NLayerDiscriminator, - commitment_weight: float = 1.0, - discriminator_weight: float = 1.0, - discriminator_factor: float = 1.0, - discriminator_iter_start: int = 1000, - ) -> None: - super().__init__() - self.reconstruction_loss = reconstruction_loss - self.discriminator = discriminator - self.commitment_weight = commitment_weight - self.discriminator_weight = discriminator_weight - self.discriminator_factor = discriminator_factor - self.discriminator_iter_start = discriminator_iter_start - - @staticmethod - def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: - """Calculates the adversarial loss.""" - loss_real = torch.mean(F.relu(1.0 - logits_real)) - loss_fake = torch.mean(F.relu(1.0 + logits_fake)) - d_loss = (loss_real + loss_fake) / 2.0 - return d_loss - - def _adaptive_weight( - self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor - ) -> Tensor: - rec_grads = torch.autograd.grad( - rec_loss, decoder_last_layer, retain_graph=True - )[0] - g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0] - d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4) - d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach() - d_weight *= self.discriminator_weight - return d_weight - - def forward( - self, - data: Tensor, - reconstructions: Tensor, - commitment_loss: Tensor, - decoder_last_layer: Tensor, - optimizer_idx: int, - global_step: int, - stage: str, - ) -> Optional[Tuple]: - """Calculates the VQGAN loss.""" - rec_loss: Tensor = self.reconstruction_loss(reconstructions, data) - - # GAN part. - if optimizer_idx == 0: - logits_fake = self.discriminator(reconstructions) - g_loss = -torch.mean(logits_fake) - - if self.training: - d_weight = self._adaptive_weight( - rec_loss=rec_loss, - g_loss=g_loss, - decoder_last_layer=decoder_last_layer, - ) - else: - d_weight = torch.tensor(0.0) - - d_factor = _adopt_weight( - self.discriminator_factor, - global_step=global_step, - threshold=self.discriminator_iter_start, - ) - - loss: Tensor = ( - rec_loss - + d_factor * d_weight * g_loss - + self.commitment_weight * commitment_loss - ) - log = { - f"{stage}/total_loss": loss, - f"{stage}/commitment_loss": commitment_loss, - f"{stage}/rec_loss": rec_loss, - f"{stage}/g_loss": g_loss, - } - return loss, log - - if optimizer_idx == 1: - logits_fake = self.discriminator(reconstructions.detach()) - logits_real = self.discriminator(data.detach()) - - d_factor = _adopt_weight( - self.discriminator_factor, - global_step=global_step, - threshold=self.discriminator_iter_start, - ) - - d_loss = d_factor * self.adversarial_loss( - logits_real=logits_real, logits_fake=logits_fake - ) - - log = { - f"{stage}/d_loss": d_loss, - } - return d_loss, log |