diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:41:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:41:39 +0200 |
commit | bec4aafe707be8e5763ad6b2194d4589f20594a9 (patch) | |
tree | 506517ca6a17241a305114e787d1b899a48a3d86 /text_recognizer/criterion | |
parent | 9a8044f4a3826a119416665741b709cd686fca87 (diff) |
Rename to criterion
Diffstat (limited to 'text_recognizer/criterion')
-rw-r--r-- | text_recognizer/criterion/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/criterion/ctc.py | 38 | ||||
-rw-r--r-- | text_recognizer/criterion/label_smoothing.py | 50 | ||||
-rw-r--r-- | text_recognizer/criterion/n_layer_discriminator.py | 59 | ||||
-rw-r--r-- | text_recognizer/criterion/vqgan_loss.py | 123 |
5 files changed, 271 insertions, 0 deletions
diff --git a/text_recognizer/criterion/__init__.py b/text_recognizer/criterion/__init__.py new file mode 100644 index 0000000..5b0a7ab --- /dev/null +++ b/text_recognizer/criterion/__init__.py @@ -0,0 +1 @@ +"""Module with custom loss functions.""" diff --git a/text_recognizer/criterion/ctc.py b/text_recognizer/criterion/ctc.py new file mode 100644 index 0000000..42a0b25 --- /dev/null +++ b/text_recognizer/criterion/ctc.py @@ -0,0 +1,38 @@ +"""CTC loss.""" +import torch +from torch import LongTensor, nn, Tensor +import torch.nn.functional as F + + +class CTCLoss(nn.Module): + """CTC loss.""" + + def __init__(self, blank: int) -> None: + super().__init__() + self.blank = blank + + def forward(self, outputs: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss.""" + device = outputs.device + + log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2) + output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device) + + targets_ = LongTensor([]).to(device) + target_lengths = LongTensor([]).to(device) + for target in targets: + # Remove padding + target = target[target != self.blank].to(device) + targets_ = torch.cat((targets_, target)) + target_lengths = torch.cat( + (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0 + ) + + return F.ctc_loss( + log_probs, + targets, + output_lengths, + target_lengths, + blank=self.blank, + zero_infinity=True, + ) diff --git a/text_recognizer/criterion/label_smoothing.py b/text_recognizer/criterion/label_smoothing.py new file mode 100644 index 0000000..5b3a47e --- /dev/null +++ b/text_recognizer/criterion/label_smoothing.py @@ -0,0 +1,50 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor + + +class LabelSmoothingLoss(nn.Module): + r"""Loss functions for making networks less over confident. + + It is used to calibrate the network so that the predicted probabilities + reflect the accuracy. The function is given by: + + L = (1 - \alpha) * y_hot + \alpha / K + + This means that some of the probability mass is transferred to the incorrect + labels, thus not forcing the network try to put all probability mass into + one label, and this works as a regulizer. + """ + + def __init__( + self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 + ) -> None: + super().__init__() + if not 0.0 < smoothing < 1.0: + raise ValueError("Smoothing must be between 0.0 and 1.0") + self.ignore_index = ignore_index + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.dim = dim + + def forward(self, output: Tensor, target: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): outputictions from the network. + target (Tensor): Ground truth. + + Shapes: + TBC + + Returns: + (Tensor): Label smoothing loss. + """ + output = output.log_softmax(dim=self.dim) + with torch.no_grad(): + true_dist = torch.zeros_like(output) + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + true_dist.masked_fill_((target == 4).unsqueeze(1), 0) + true_dist += self.smoothing / output.size(self.dim) + return torch.mean(torch.sum(-true_dist * output, dim=self.dim)) diff --git a/text_recognizer/criterion/n_layer_discriminator.py b/text_recognizer/criterion/n_layer_discriminator.py new file mode 100644 index 0000000..a9f47f9 --- /dev/null +++ b/text_recognizer/criterion/n_layer_discriminator.py @@ -0,0 +1,59 @@ +"""Pix2pix discriminator loss.""" +from torch import nn, Tensor + +from text_recognizer.networks.vqvae.norm import Normalize + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator loss in Pix2Pix.""" + + def __init__( + self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 + ) -> None: + super().__init__() + self.in_channels = in_channels + self.num_channels = num_channels + self.num_layers = num_layers + self.discriminator = self._build_discriminator() + + def _build_discriminator(self) -> nn.Sequential: + """Builds discriminator.""" + discriminator = [ + nn.Sigmoid(), + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.num_channels, + kernel_size=4, + stride=2, + padding=1, + ), + nn.Mish(inplace=True), + ] + in_channels = self.num_channels + for n in range(1, self.num_layers): + discriminator += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels * n, + kernel_size=4, + stride=2, + padding=1, + ), + # Normalize(num_channels=in_channels * n), + nn.Mish(inplace=True), + ] + in_channels *= n + + discriminator += [ + nn.Conv2d( + in_channels=self.num_channels * (self.num_layers - 1), + out_channels=1, + kernel_size=4, + padding=1, + ) + ] + return nn.Sequential(*discriminator) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through discriminator.""" + return self.discriminator(x) diff --git a/text_recognizer/criterion/vqgan_loss.py b/text_recognizer/criterion/vqgan_loss.py new file mode 100644 index 0000000..9d1cddd --- /dev/null +++ b/text_recognizer/criterion/vqgan_loss.py @@ -0,0 +1,123 @@ +"""VQGAN loss for PyTorch Lightning.""" +from typing import Optional, Tuple + +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator + + +def _adopt_weight( + weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 +) -> float: + """Sets weight to value after the threshold is passed.""" + if global_step < threshold: + weight = value + return weight + + +class VQGANLoss(nn.Module): + """VQGAN loss.""" + + def __init__( + self, + reconstruction_loss: nn.L1Loss, + discriminator: NLayerDiscriminator, + commitment_weight: float = 1.0, + discriminator_weight: float = 1.0, + discriminator_factor: float = 1.0, + discriminator_iter_start: int = 1000, + ) -> None: + super().__init__() + self.reconstruction_loss = reconstruction_loss + self.discriminator = discriminator + self.commitment_weight = commitment_weight + self.discriminator_weight = discriminator_weight + self.discriminator_factor = discriminator_factor + self.discriminator_iter_start = discriminator_iter_start + + @staticmethod + def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: + """Calculates the adversarial loss.""" + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = (loss_real + loss_fake) / 2.0 + return d_loss + + def _adaptive_weight( + self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor + ) -> Tensor: + rec_grads = torch.autograd.grad( + rec_loss, decoder_last_layer, retain_graph=True + )[0] + g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0] + d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4) + d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach() + d_weight *= self.discriminator_weight + return d_weight + + def forward( + self, + data: Tensor, + reconstructions: Tensor, + commitment_loss: Tensor, + decoder_last_layer: Tensor, + optimizer_idx: int, + global_step: int, + stage: str, + ) -> Optional[Tuple]: + """Calculates the VQGAN loss.""" + rec_loss: Tensor = self.reconstruction_loss(reconstructions, data) + + # GAN part. + if optimizer_idx == 0: + logits_fake = self.discriminator(reconstructions) + g_loss = -torch.mean(logits_fake) + + if self.training: + d_weight = self._adaptive_weight( + rec_loss=rec_loss, + g_loss=g_loss, + decoder_last_layer=decoder_last_layer, + ) + else: + d_weight = torch.tensor(0.0) + + d_factor = _adopt_weight( + self.discriminator_factor, + global_step=global_step, + threshold=self.discriminator_iter_start, + ) + + loss: Tensor = ( + rec_loss + + d_factor * d_weight * g_loss + + self.commitment_weight * commitment_loss + ) + log = { + f"{stage}/total_loss": loss, + f"{stage}/commitment_loss": commitment_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/g_loss": g_loss, + } + return loss, log + + if optimizer_idx == 1: + logits_fake = self.discriminator(reconstructions.detach()) + logits_real = self.discriminator(data.detach()) + + d_factor = _adopt_weight( + self.discriminator_factor, + global_step=global_step, + threshold=self.discriminator_iter_start, + ) + + d_loss = d_factor * self.adversarial_loss( + logits_real=logits_real, logits_fake=logits_fake + ) + + log = { + f"{stage}/d_loss": d_loss, + } + return d_loss, log |