From bec4aafe707be8e5763ad6b2194d4589f20594a9 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 27 Oct 2021 22:41:39 +0200 Subject: Rename to criterion --- text_recognizer/criterion/__init__.py | 1 + text_recognizer/criterion/ctc.py | 38 +++ text_recognizer/criterion/label_smoothing.py | 50 +++ text_recognizer/criterion/n_layer_discriminator.py | 59 ++++ text_recognizer/criterion/vqgan_loss.py | 123 +++++++ text_recognizer/criterions/__init__.py | 1 - text_recognizer/criterions/ctc.py | 38 --- text_recognizer/criterions/label_smoothing.py | 50 --- .../criterions/n_layer_discriminator.py | 59 ---- text_recognizer/criterions/transducer.py | 376 --------------------- text_recognizer/criterions/vqgan_loss.py | 123 ------- 11 files changed, 271 insertions(+), 647 deletions(-) create mode 100644 text_recognizer/criterion/__init__.py create mode 100644 text_recognizer/criterion/ctc.py create mode 100644 text_recognizer/criterion/label_smoothing.py create mode 100644 text_recognizer/criterion/n_layer_discriminator.py create mode 100644 text_recognizer/criterion/vqgan_loss.py delete mode 100644 text_recognizer/criterions/__init__.py delete mode 100644 text_recognizer/criterions/ctc.py delete mode 100644 text_recognizer/criterions/label_smoothing.py delete mode 100644 text_recognizer/criterions/n_layer_discriminator.py delete mode 100644 text_recognizer/criterions/transducer.py delete mode 100644 text_recognizer/criterions/vqgan_loss.py (limited to 'text_recognizer') diff --git a/text_recognizer/criterion/__init__.py b/text_recognizer/criterion/__init__.py new file mode 100644 index 0000000..5b0a7ab --- /dev/null +++ b/text_recognizer/criterion/__init__.py @@ -0,0 +1 @@ +"""Module with custom loss functions.""" diff --git a/text_recognizer/criterion/ctc.py b/text_recognizer/criterion/ctc.py new file mode 100644 index 0000000..42a0b25 --- /dev/null +++ b/text_recognizer/criterion/ctc.py @@ -0,0 +1,38 @@ +"""CTC loss.""" +import torch +from torch import LongTensor, nn, Tensor +import torch.nn.functional as F + + +class CTCLoss(nn.Module): + """CTC loss.""" + + def __init__(self, blank: int) -> None: + super().__init__() + self.blank = blank + + def forward(self, outputs: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss.""" + device = outputs.device + + log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2) + output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device) + + targets_ = LongTensor([]).to(device) + target_lengths = LongTensor([]).to(device) + for target in targets: + # Remove padding + target = target[target != self.blank].to(device) + targets_ = torch.cat((targets_, target)) + target_lengths = torch.cat( + (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0 + ) + + return F.ctc_loss( + log_probs, + targets, + output_lengths, + target_lengths, + blank=self.blank, + zero_infinity=True, + ) diff --git a/text_recognizer/criterion/label_smoothing.py b/text_recognizer/criterion/label_smoothing.py new file mode 100644 index 0000000..5b3a47e --- /dev/null +++ b/text_recognizer/criterion/label_smoothing.py @@ -0,0 +1,50 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor + + +class LabelSmoothingLoss(nn.Module): + r"""Loss functions for making networks less over confident. + + It is used to calibrate the network so that the predicted probabilities + reflect the accuracy. The function is given by: + + L = (1 - \alpha) * y_hot + \alpha / K + + This means that some of the probability mass is transferred to the incorrect + labels, thus not forcing the network try to put all probability mass into + one label, and this works as a regulizer. + """ + + def __init__( + self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 + ) -> None: + super().__init__() + if not 0.0 < smoothing < 1.0: + raise ValueError("Smoothing must be between 0.0 and 1.0") + self.ignore_index = ignore_index + self.confidence = 1.0 - smoothing + self.smoothing = smoothing + self.dim = dim + + def forward(self, output: Tensor, target: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): outputictions from the network. + target (Tensor): Ground truth. + + Shapes: + TBC + + Returns: + (Tensor): Label smoothing loss. + """ + output = output.log_softmax(dim=self.dim) + with torch.no_grad(): + true_dist = torch.zeros_like(output) + true_dist.scatter_(1, target.unsqueeze(1), self.confidence) + true_dist.masked_fill_((target == 4).unsqueeze(1), 0) + true_dist += self.smoothing / output.size(self.dim) + return torch.mean(torch.sum(-true_dist * output, dim=self.dim)) diff --git a/text_recognizer/criterion/n_layer_discriminator.py b/text_recognizer/criterion/n_layer_discriminator.py new file mode 100644 index 0000000..a9f47f9 --- /dev/null +++ b/text_recognizer/criterion/n_layer_discriminator.py @@ -0,0 +1,59 @@ +"""Pix2pix discriminator loss.""" +from torch import nn, Tensor + +from text_recognizer.networks.vqvae.norm import Normalize + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator loss in Pix2Pix.""" + + def __init__( + self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 + ) -> None: + super().__init__() + self.in_channels = in_channels + self.num_channels = num_channels + self.num_layers = num_layers + self.discriminator = self._build_discriminator() + + def _build_discriminator(self) -> nn.Sequential: + """Builds discriminator.""" + discriminator = [ + nn.Sigmoid(), + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.num_channels, + kernel_size=4, + stride=2, + padding=1, + ), + nn.Mish(inplace=True), + ] + in_channels = self.num_channels + for n in range(1, self.num_layers): + discriminator += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels * n, + kernel_size=4, + stride=2, + padding=1, + ), + # Normalize(num_channels=in_channels * n), + nn.Mish(inplace=True), + ] + in_channels *= n + + discriminator += [ + nn.Conv2d( + in_channels=self.num_channels * (self.num_layers - 1), + out_channels=1, + kernel_size=4, + padding=1, + ) + ] + return nn.Sequential(*discriminator) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through discriminator.""" + return self.discriminator(x) diff --git a/text_recognizer/criterion/vqgan_loss.py b/text_recognizer/criterion/vqgan_loss.py new file mode 100644 index 0000000..9d1cddd --- /dev/null +++ b/text_recognizer/criterion/vqgan_loss.py @@ -0,0 +1,123 @@ +"""VQGAN loss for PyTorch Lightning.""" +from typing import Optional, Tuple + +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator + + +def _adopt_weight( + weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 +) -> float: + """Sets weight to value after the threshold is passed.""" + if global_step < threshold: + weight = value + return weight + + +class VQGANLoss(nn.Module): + """VQGAN loss.""" + + def __init__( + self, + reconstruction_loss: nn.L1Loss, + discriminator: NLayerDiscriminator, + commitment_weight: float = 1.0, + discriminator_weight: float = 1.0, + discriminator_factor: float = 1.0, + discriminator_iter_start: int = 1000, + ) -> None: + super().__init__() + self.reconstruction_loss = reconstruction_loss + self.discriminator = discriminator + self.commitment_weight = commitment_weight + self.discriminator_weight = discriminator_weight + self.discriminator_factor = discriminator_factor + self.discriminator_iter_start = discriminator_iter_start + + @staticmethod + def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: + """Calculates the adversarial loss.""" + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = (loss_real + loss_fake) / 2.0 + return d_loss + + def _adaptive_weight( + self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor + ) -> Tensor: + rec_grads = torch.autograd.grad( + rec_loss, decoder_last_layer, retain_graph=True + )[0] + g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0] + d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4) + d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach() + d_weight *= self.discriminator_weight + return d_weight + + def forward( + self, + data: Tensor, + reconstructions: Tensor, + commitment_loss: Tensor, + decoder_last_layer: Tensor, + optimizer_idx: int, + global_step: int, + stage: str, + ) -> Optional[Tuple]: + """Calculates the VQGAN loss.""" + rec_loss: Tensor = self.reconstruction_loss(reconstructions, data) + + # GAN part. + if optimizer_idx == 0: + logits_fake = self.discriminator(reconstructions) + g_loss = -torch.mean(logits_fake) + + if self.training: + d_weight = self._adaptive_weight( + rec_loss=rec_loss, + g_loss=g_loss, + decoder_last_layer=decoder_last_layer, + ) + else: + d_weight = torch.tensor(0.0) + + d_factor = _adopt_weight( + self.discriminator_factor, + global_step=global_step, + threshold=self.discriminator_iter_start, + ) + + loss: Tensor = ( + rec_loss + + d_factor * d_weight * g_loss + + self.commitment_weight * commitment_loss + ) + log = { + f"{stage}/total_loss": loss, + f"{stage}/commitment_loss": commitment_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/g_loss": g_loss, + } + return loss, log + + if optimizer_idx == 1: + logits_fake = self.discriminator(reconstructions.detach()) + logits_real = self.discriminator(data.detach()) + + d_factor = _adopt_weight( + self.discriminator_factor, + global_step=global_step, + threshold=self.discriminator_iter_start, + ) + + d_loss = d_factor * self.adversarial_loss( + logits_real=logits_real, logits_fake=logits_fake + ) + + log = { + f"{stage}/d_loss": d_loss, + } + return d_loss, log diff --git a/text_recognizer/criterions/__init__.py b/text_recognizer/criterions/__init__.py deleted file mode 100644 index 5b0a7ab..0000000 --- a/text_recognizer/criterions/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module with custom loss functions.""" diff --git a/text_recognizer/criterions/ctc.py b/text_recognizer/criterions/ctc.py deleted file mode 100644 index 42a0b25..0000000 --- a/text_recognizer/criterions/ctc.py +++ /dev/null @@ -1,38 +0,0 @@ -"""CTC loss.""" -import torch -from torch import LongTensor, nn, Tensor -import torch.nn.functional as F - - -class CTCLoss(nn.Module): - """CTC loss.""" - - def __init__(self, blank: int) -> None: - super().__init__() - self.blank = blank - - def forward(self, outputs: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss.""" - device = outputs.device - - log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2) - output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device) - - targets_ = LongTensor([]).to(device) - target_lengths = LongTensor([]).to(device) - for target in targets: - # Remove padding - target = target[target != self.blank].to(device) - targets_ = torch.cat((targets_, target)) - target_lengths = torch.cat( - (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0 - ) - - return F.ctc_loss( - log_probs, - targets, - output_lengths, - target_lengths, - blank=self.blank, - zero_infinity=True, - ) diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py deleted file mode 100644 index 5b3a47e..0000000 --- a/text_recognizer/criterions/label_smoothing.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor - - -class LabelSmoothingLoss(nn.Module): - r"""Loss functions for making networks less over confident. - - It is used to calibrate the network so that the predicted probabilities - reflect the accuracy. The function is given by: - - L = (1 - \alpha) * y_hot + \alpha / K - - This means that some of the probability mass is transferred to the incorrect - labels, thus not forcing the network try to put all probability mass into - one label, and this works as a regulizer. - """ - - def __init__( - self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 - ) -> None: - super().__init__() - if not 0.0 < smoothing < 1.0: - raise ValueError("Smoothing must be between 0.0 and 1.0") - self.ignore_index = ignore_index - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.dim = dim - - def forward(self, output: Tensor, target: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): outputictions from the network. - target (Tensor): Ground truth. - - Shapes: - TBC - - Returns: - (Tensor): Label smoothing loss. - """ - output = output.log_softmax(dim=self.dim) - with torch.no_grad(): - true_dist = torch.zeros_like(output) - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) - true_dist.masked_fill_((target == 4).unsqueeze(1), 0) - true_dist += self.smoothing / output.size(self.dim) - return torch.mean(torch.sum(-true_dist * output, dim=self.dim)) diff --git a/text_recognizer/criterions/n_layer_discriminator.py b/text_recognizer/criterions/n_layer_discriminator.py deleted file mode 100644 index a9f47f9..0000000 --- a/text_recognizer/criterions/n_layer_discriminator.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Pix2pix discriminator loss.""" -from torch import nn, Tensor - -from text_recognizer.networks.vqvae.norm import Normalize - - -class NLayerDiscriminator(nn.Module): - """Defines a PatchGAN discriminator loss in Pix2Pix.""" - - def __init__( - self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 - ) -> None: - super().__init__() - self.in_channels = in_channels - self.num_channels = num_channels - self.num_layers = num_layers - self.discriminator = self._build_discriminator() - - def _build_discriminator(self) -> nn.Sequential: - """Builds discriminator.""" - discriminator = [ - nn.Sigmoid(), - nn.Conv2d( - in_channels=self.in_channels, - out_channels=self.num_channels, - kernel_size=4, - stride=2, - padding=1, - ), - nn.Mish(inplace=True), - ] - in_channels = self.num_channels - for n in range(1, self.num_layers): - discriminator += [ - nn.Conv2d( - in_channels=in_channels, - out_channels=in_channels * n, - kernel_size=4, - stride=2, - padding=1, - ), - # Normalize(num_channels=in_channels * n), - nn.Mish(inplace=True), - ] - in_channels *= n - - discriminator += [ - nn.Conv2d( - in_channels=self.num_channels * (self.num_layers - 1), - out_channels=1, - kernel_size=4, - padding=1, - ) - ] - return nn.Sequential(*discriminator) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass through discriminator.""" - return self.discriminator(x) diff --git a/text_recognizer/criterions/transducer.py b/text_recognizer/criterions/transducer.py deleted file mode 100644 index 089bff7..0000000 --- a/text_recognizer/criterions/transducer.py +++ /dev/null @@ -1,376 +0,0 @@ -"""Transducer and the transducer loss function.py - -Stolen from: - https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py - -""" -from pathlib import Path -import itertools -from typing import Dict, List, Optional, Sequence, Set, Tuple - -import gtn -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.data.utils.iam_preprocessor import Preprocessor - - -def make_scalar_graph(weight) -> gtn.Graph: - scalar = gtn.Graph() - scalar.add_node(True) - scalar.add_node(False, True) - scalar.add_arc(0, 1, 0, 0, weight) - return scalar - - -def make_chain_graph(sequence) -> gtn.Graph: - graph = gtn.Graph(False) - graph.add_node(True) - for i, s in enumerate(sequence): - graph.add_node(False, i == (len(sequence) - 1)) - graph.add_arc(i, i + 1, s) - return graph - - -def make_transitions_graph( - ngram: int, num_tokens: int, calc_grad: bool = False -) -> gtn.Graph: - transitions = gtn.Graph(calc_grad) - transitions.add_node(True, ngram == 1) - - state_map = {(): 0} - - # First build transitions which include : - for n in range(1, ngram): - for state in itertools.product(range(num_tokens), repeat=n): - in_idx = state_map[state[:-1]] - out_idx = transitions.add_node(False, ngram == 1) - state_map[state] = out_idx - transitions.add_arc(in_idx, out_idx, state[-1]) - - for state in itertools.product(range(num_tokens), repeat=ngram): - state_idx = state_map[state[:-1]] - new_state_idx = state_map[state[1:]] - # p(state[-1] | state[:-1]) - transitions.add_arc(state_idx, new_state_idx, state[-1]) - - if ngram > 1: - # Build transitions which include : - end_idx = transitions.add_node(False, True) - for in_idx in range(end_idx): - transitions.add_arc(in_idx, end_idx, gtn.epsilon) - - return transitions - - -def make_lexicon_graph( - word_pieces: List, graphemes_to_idx: Dict, special_tokens: Optional[Set] -) -> gtn.Graph: - """Constructs a graph which transduces letters to word pieces.""" - graph = gtn.Graph(False) - graph.add_node(True, True) - for i, wp in enumerate(word_pieces): - prev = 0 - if special_tokens is not None and wp in special_tokens: - n = graph.add_node() - graph.add_arc(prev, n, graphemes_to_idx[wp], i) - else: - for character in wp[:-1]: - n = graph.add_node() - graph.add_arc(prev, n, graphemes_to_idx[character], gtn.epsilon) - prev = n - graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) - graph.arc_sort() - return graph - - -def make_token_graph( - token_list: List, blank: str = "none", allow_repeats: bool = True -) -> gtn.Graph: - """Constructs a graph with all the individual token transition models.""" - if not allow_repeats and blank != "optional": - raise ValueError("Must use blank='optional' if disallowing repeats.") - - ntoks = len(token_list) - graph = gtn.Graph(False) - - # Creating nodes - graph.add_node(True, True) - for i in range(ntoks): - # We can consume one or more consecutive word - # pieces for each emission: - # E.g. [ab, ab, ab] transduces to [ab] - graph.add_node(False, blank != "forced") - - if blank != "none": - graph.add_node() - - # Creating arcs - if blank != "none": - # Blank index is assumed to be last (ntoks) - graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) - graph.add_arc(ntoks + 1, 0, gtn.epsilon) - - for i in range(ntoks): - graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) - graph.add_arc(i + 1, i + 1, i, gtn.epsilon) - - if allow_repeats: - if blank == "forced": - # Allow transitions from token to blank only - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - else: - # Allow transition from token to blank and all other tokens - graph.add_arc(i + 1, 0, gtn.epsilon) - - else: - # allow transitions to blank and all other tokens except the same token - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - for j in range(ntoks): - if i != j: - graph.add_arc(i + 1, j + 1, j, j) - - return graph - - -class TransducerLossFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - inputs, - targets, - tokens, - lexicon, - transition_params=None, - transitions=None, - reduction="none", - ) -> Tensor: - B, T, C = inputs.shape - - losses = [None] * B - emissions_graphs = [None] * B - - if transitions is not None: - if transition_params is None: - raise ValueError("Specified transitions, but not transition params.") - - cpu_data = transition_params.cpu().contiguous() - transitions.set_weights(cpu_data.data_ptr()) - transitions.calc_grad = transition_params.requires_grad - transitions.zero_grad() - - def process(b: int) -> None: - # Create emission graph: - emissions = gtn.linear_graph(T, C, inputs.requires_grad) - cpu_data = inputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - target = make_chain_graph(targets[b]) - target.arc_sort(True) - - # Create token tot grapheme decomposition graph - tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) - tokens_target.arc_sort() - - # Create alignment graph: - aligments = gtn.project_input( - gtn.remove(gtn.compose(tokens, tokens_target)) - ) - aligments.arc_sort() - - # Add transitions scores: - if transitions is not None: - aligments = gtn.intersect(transitions, aligments) - aligments.arc_sort() - - loss = gtn.forward_score(gtn.intersect(emissions, aligments)) - - # Normalize if needed: - if transitions is not None: - norm = gtn.forward_score(gtn.intersect(emissions, transitions)) - loss = gtn.subtract(loss, norm) - - losses[b] = gtn.negate(loss) - - # Save for backward: - if emissions.calc_grad: - emissions_graphs[b] = emissions - - gtn.parallel_for(process, range(B)) - - ctx.graphs = (losses, emissions_graphs, transitions) - ctx.input_shape = inputs.shape - - # Optionally reduce by target length - if reduction == "mean": - scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] - else: - scales = [1.0] * B - - ctx.scales = scales - - loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) - return torch.mean(loss.to(inputs.device)) - - @staticmethod - def backward(ctx, grad_output) -> Tuple: - losses, emissions_graphs, transitions = ctx.graphs - scales = ctx.scales - - B, T, C = ctx.input_shape - calc_emissions = ctx.needs_input_grad[0] - input_grad = torch.empty((B, T, C)) if calc_emissions else None - - def process(b: int) -> None: - scale = make_scalar_graph(scales[b]) - gtn.backward(losses[b], scale) - emissions = emissions_graphs[b] - if calc_emissions: - grad = emissions.grad().weights_to_numpy() - input_grad[b] = torch.tensor(grad).view(1, T, C) - - gtn.parallel_for(process, range(B)) - - if calc_emissions: - input_grad = input_grad.to(grad_output.device) - input_grad *= grad_output / B - - if ctx.needs_input_grad[4]: - grad = transitions.grad().weights_to_numpy() - transition_grad = torch.tensor(grad).to(grad_output.device) - transition_grad *= grad_output / B - else: - transition_grad = None - - return ( - input_grad, - None, # target - None, # tokens - None, # lexicon - transition_grad, # transition params - None, # transitions graph - None, - ) - - -TransducerLoss = TransducerLossFunction.apply - - -class Transducer(nn.Module): - def __init__( - self, - preprocessor: Preprocessor, - ngram: int = 0, - transitions: str = None, - blank: str = "none", - allow_repeats: bool = True, - reduction: str = "none", - ) -> None: - """A generic transducer loss function. - - Args: - preprocessor (Preprocessor) : The IAM preprocessor for word pieces. - ngram (int) : Order of the token-level transition model. If `ngram=0` - then no transition model is used. - blank (string) : Specifies the usage of blank token - 'none' - do not use blank token - 'optional' - allow an optional blank inbetween tokens - 'forced' - force a blank inbetween tokens (also referred to as garbage token) - allow_repeats (boolean) : If false, then we don't allow paths with - consecutive tokens in the alignment graph. This keeps the graph - unambiguous in the sense that the same input cannot transduce to - different outputs. - """ - super().__init__() - if blank not in ["optional", "forced", "none"]: - raise ValueError( - "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']" - ) - self.tokens = make_token_graph( - preprocessor.tokens, blank=blank, allow_repeats=allow_repeats - ) - self.lexicon = make_lexicon_graph( - preprocessor.tokens, - preprocessor.graphemes_to_index, - preprocessor.special_tokens, - ) - self.ngram = ngram - - self.transitions: Optional[gtn.Graph] = None - self.transitions_params: Optional[nn.Parameter] = None - self._load_transitions(transitions, preprocessor, blank) - - if ngram > 0 and transitions is not None: - raise ValueError("Only one of ngram and transitions may be specified") - - self.reduction = reduction - - def _load_transitions( - self, transitions: Optional[str], preprocessor: Preprocessor, blank: str - ): - """Loads transition graph.""" - processed_path = ( - Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" - ) - if transitions is not None: - transitions = gtn.load(str(processed_path / transitions)) - if self.ngram > 0: - self.transitions = make_transitions_graph( - self.ngram, len(preprocessor.tokens) + int(blank != "none"), True - ) - if transitions is not None: - self.transitions = transitions - self.transitions.arc_sort() - self.transitions_params = nn.Parameter( - torch.zeros(self.transitions.num_arcs()) - ) - - def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: - return TransducerLoss( - inputs, - targets, - self.tokens, - self.lexicon, - self.transitions_params, - self.transitions, - self.reduction, - ) - - def viterbi(self, outputs: Tensor) -> List[Tensor]: - B, T, C = outputs.shape - - if self.transitions is not None: - cpu_data = self.transition_params.cpu().contiguous() - self.transitions.set_weights(cpu_data.data_ptr()) - self.transitions.calc_grad = False - - self.tokens.arc_sort() - - paths = [None] * B - - def process(b: int) -> None: - emissions = gtn.linear_graph(T, C, False) - cpu_data = outputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - - if self.transitions is not None: - full_graph = gtn.intersect(emissions, self.transitions) - else: - full_graph = emissions - - # Find the best path and remove back-off arcs: - path = gtn.remove(gtn.viterbi_path(full_graph)) - - # Left compose the viterbi path with the "aligment to token" - # transducer to get the outputs: - path = gtn.compose(path, self.tokens) - - # When there are ambiguous paths (allow_repeats is true), we take - # the shortest: - path = gtn.viterbi_path(path) - path = gtn.remove(gtn.project_output(path)) - paths[b] = path.labels_to_list() - - gtn.parallel_for(process, range(B)) - predictions = [torch.IntTensor(path) for path in paths] - return predictions diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py deleted file mode 100644 index 9d1cddd..0000000 --- a/text_recognizer/criterions/vqgan_loss.py +++ /dev/null @@ -1,123 +0,0 @@ -"""VQGAN loss for PyTorch Lightning.""" -from typing import Optional, Tuple - -import torch -from torch import nn, Tensor -import torch.nn.functional as F - -from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator - - -def _adopt_weight( - weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 -) -> float: - """Sets weight to value after the threshold is passed.""" - if global_step < threshold: - weight = value - return weight - - -class VQGANLoss(nn.Module): - """VQGAN loss.""" - - def __init__( - self, - reconstruction_loss: nn.L1Loss, - discriminator: NLayerDiscriminator, - commitment_weight: float = 1.0, - discriminator_weight: float = 1.0, - discriminator_factor: float = 1.0, - discriminator_iter_start: int = 1000, - ) -> None: - super().__init__() - self.reconstruction_loss = reconstruction_loss - self.discriminator = discriminator - self.commitment_weight = commitment_weight - self.discriminator_weight = discriminator_weight - self.discriminator_factor = discriminator_factor - self.discriminator_iter_start = discriminator_iter_start - - @staticmethod - def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: - """Calculates the adversarial loss.""" - loss_real = torch.mean(F.relu(1.0 - logits_real)) - loss_fake = torch.mean(F.relu(1.0 + logits_fake)) - d_loss = (loss_real + loss_fake) / 2.0 - return d_loss - - def _adaptive_weight( - self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor - ) -> Tensor: - rec_grads = torch.autograd.grad( - rec_loss, decoder_last_layer, retain_graph=True - )[0] - g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0] - d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4) - d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach() - d_weight *= self.discriminator_weight - return d_weight - - def forward( - self, - data: Tensor, - reconstructions: Tensor, - commitment_loss: Tensor, - decoder_last_layer: Tensor, - optimizer_idx: int, - global_step: int, - stage: str, - ) -> Optional[Tuple]: - """Calculates the VQGAN loss.""" - rec_loss: Tensor = self.reconstruction_loss(reconstructions, data) - - # GAN part. - if optimizer_idx == 0: - logits_fake = self.discriminator(reconstructions) - g_loss = -torch.mean(logits_fake) - - if self.training: - d_weight = self._adaptive_weight( - rec_loss=rec_loss, - g_loss=g_loss, - decoder_last_layer=decoder_last_layer, - ) - else: - d_weight = torch.tensor(0.0) - - d_factor = _adopt_weight( - self.discriminator_factor, - global_step=global_step, - threshold=self.discriminator_iter_start, - ) - - loss: Tensor = ( - rec_loss - + d_factor * d_weight * g_loss - + self.commitment_weight * commitment_loss - ) - log = { - f"{stage}/total_loss": loss, - f"{stage}/commitment_loss": commitment_loss, - f"{stage}/rec_loss": rec_loss, - f"{stage}/g_loss": g_loss, - } - return loss, log - - if optimizer_idx == 1: - logits_fake = self.discriminator(reconstructions.detach()) - logits_real = self.discriminator(data.detach()) - - d_factor = _adopt_weight( - self.discriminator_factor, - global_step=global_step, - threshold=self.discriminator_iter_start, - ) - - d_loss = d_factor * self.adversarial_loss( - logits_real=logits_real, logits_fake=logits_fake - ) - - log = { - f"{stage}/d_loss": d_loss, - } - return d_loss, log -- cgit v1.2.3-70-g09d2