summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r--text_recognizer/criterions/__init__.py1
-rw-r--r--text_recognizer/criterions/ctc.py38
-rw-r--r--text_recognizer/criterions/label_smoothing.py50
-rw-r--r--text_recognizer/criterions/n_layer_discriminator.py59
-rw-r--r--text_recognizer/criterions/transducer.py376
-rw-r--r--text_recognizer/criterions/vqgan_loss.py123
6 files changed, 0 insertions, 647 deletions
diff --git a/text_recognizer/criterions/__init__.py b/text_recognizer/criterions/__init__.py
deleted file mode 100644
index 5b0a7ab..0000000
--- a/text_recognizer/criterions/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Module with custom loss functions."""
diff --git a/text_recognizer/criterions/ctc.py b/text_recognizer/criterions/ctc.py
deleted file mode 100644
index 42a0b25..0000000
--- a/text_recognizer/criterions/ctc.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""CTC loss."""
-import torch
-from torch import LongTensor, nn, Tensor
-import torch.nn.functional as F
-
-
-class CTCLoss(nn.Module):
- """CTC loss."""
-
- def __init__(self, blank: int) -> None:
- super().__init__()
- self.blank = blank
-
- def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
- """Computes the CTC loss."""
- device = outputs.device
-
- log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
- output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device)
-
- targets_ = LongTensor([]).to(device)
- target_lengths = LongTensor([]).to(device)
- for target in targets:
- # Remove padding
- target = target[target != self.blank].to(device)
- targets_ = torch.cat((targets_, target))
- target_lengths = torch.cat(
- (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0
- )
-
- return F.ctc_loss(
- log_probs,
- targets,
- output_lengths,
- target_lengths,
- blank=self.blank,
- zero_infinity=True,
- )
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py
deleted file mode 100644
index 5b3a47e..0000000
--- a/text_recognizer/criterions/label_smoothing.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Implementations of custom loss functions."""
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class LabelSmoothingLoss(nn.Module):
- r"""Loss functions for making networks less over confident.
-
- It is used to calibrate the network so that the predicted probabilities
- reflect the accuracy. The function is given by:
-
- L = (1 - \alpha) * y_hot + \alpha / K
-
- This means that some of the probability mass is transferred to the incorrect
- labels, thus not forcing the network to try to put all probability mass into
- one label, and this works as a regularizer.
- """
-
- def __init__(
- self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1
- ) -> None:
- super().__init__()
- if not 0.0 < smoothing < 1.0:
- raise ValueError("Smoothing must be between 0.0 and 1.0")
- self.ignore_index = ignore_index
- self.confidence = 1.0 - smoothing
- self.smoothing = smoothing
- self.dim = dim
-
- def forward(self, output: Tensor, target: Tensor) -> Tensor:
- """Computes the loss.
-
- Args:
- output (Tensor): Predictions from the network.
- target (Tensor): Ground truth.
-
- Shapes:
- TBC
-
- Returns:
- (Tensor): Label smoothing loss.
- """
- output = output.log_softmax(dim=self.dim)
- with torch.no_grad():
- true_dist = torch.zeros_like(output)
- true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
- true_dist.masked_fill_((target == 4).unsqueeze(1), 0)
- true_dist += self.smoothing / output.size(self.dim)
- return torch.mean(torch.sum(-true_dist * output, dim=self.dim))
diff --git a/text_recognizer/criterions/n_layer_discriminator.py b/text_recognizer/criterions/n_layer_discriminator.py
deleted file mode 100644
index a9f47f9..0000000
--- a/text_recognizer/criterions/n_layer_discriminator.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Pix2pix discriminator loss."""
-from torch import nn, Tensor
-
-from text_recognizer.networks.vqvae.norm import Normalize
-
-
-class NLayerDiscriminator(nn.Module):
- """Defines a PatchGAN discriminator loss in Pix2Pix."""
-
- def __init__(
- self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.num_channels = num_channels
- self.num_layers = num_layers
- self.discriminator = self._build_discriminator()
-
- def _build_discriminator(self) -> nn.Sequential:
- """Builds discriminator."""
- discriminator = [
- nn.Sigmoid(),
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.num_channels,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- nn.Mish(inplace=True),
- ]
- in_channels = self.num_channels
- for n in range(1, self.num_layers):
- discriminator += [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=in_channels * n,
- kernel_size=4,
- stride=2,
- padding=1,
- ),
- # Normalize(num_channels=in_channels * n),
- nn.Mish(inplace=True),
- ]
- in_channels *= n
-
- discriminator += [
- nn.Conv2d(
- in_channels=self.num_channels * (self.num_layers - 1),
- out_channels=1,
- kernel_size=4,
- padding=1,
- )
- ]
- return nn.Sequential(*discriminator)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass through discriminator."""
- return self.discriminator(x)
diff --git a/text_recognizer/criterions/transducer.py b/text_recognizer/criterions/transducer.py
deleted file mode 100644
index 089bff7..0000000
--- a/text_recognizer/criterions/transducer.py
+++ /dev/null
@@ -1,376 +0,0 @@
-"""Transducer and the transducer loss function.
-
-Stolen from:
- https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
-
-"""
-from pathlib import Path
-import itertools
-from typing import Dict, List, Optional, Sequence, Set, Tuple
-
-import gtn
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.data.utils.iam_preprocessor import Preprocessor
-
-
-def make_scalar_graph(weight) -> gtn.Graph:
- scalar = gtn.Graph()
- scalar.add_node(True)
- scalar.add_node(False, True)
- scalar.add_arc(0, 1, 0, 0, weight)
- return scalar
-
-
-def make_chain_graph(sequence) -> gtn.Graph:
- graph = gtn.Graph(False)
- graph.add_node(True)
- for i, s in enumerate(sequence):
- graph.add_node(False, i == (len(sequence) - 1))
- graph.add_arc(i, i + 1, s)
- return graph
-
-
-def make_transitions_graph(
- ngram: int, num_tokens: int, calc_grad: bool = False
-) -> gtn.Graph:
- transitions = gtn.Graph(calc_grad)
- transitions.add_node(True, ngram == 1)
-
- state_map = {(): 0}
-
- # First build transitions which include <s>:
- for n in range(1, ngram):
- for state in itertools.product(range(num_tokens), repeat=n):
- in_idx = state_map[state[:-1]]
- out_idx = transitions.add_node(False, ngram == 1)
- state_map[state] = out_idx
- transitions.add_arc(in_idx, out_idx, state[-1])
-
- for state in itertools.product(range(num_tokens), repeat=ngram):
- state_idx = state_map[state[:-1]]
- new_state_idx = state_map[state[1:]]
- # p(state[-1] | state[:-1])
- transitions.add_arc(state_idx, new_state_idx, state[-1])
-
- if ngram > 1:
- # Build transitions which include </s>:
- end_idx = transitions.add_node(False, True)
- for in_idx in range(end_idx):
- transitions.add_arc(in_idx, end_idx, gtn.epsilon)
-
- return transitions
-
-
-def make_lexicon_graph(
- word_pieces: List, graphemes_to_idx: Dict, special_tokens: Optional[Set]
-) -> gtn.Graph:
- """Constructs a graph which transduces letters to word pieces."""
- graph = gtn.Graph(False)
- graph.add_node(True, True)
- for i, wp in enumerate(word_pieces):
- prev = 0
- if special_tokens is not None and wp in special_tokens:
- n = graph.add_node()
- graph.add_arc(prev, n, graphemes_to_idx[wp], i)
- else:
- for character in wp[:-1]:
- n = graph.add_node()
- graph.add_arc(prev, n, graphemes_to_idx[character], gtn.epsilon)
- prev = n
- graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
- graph.arc_sort()
- return graph
-
-
-def make_token_graph(
- token_list: List, blank: str = "none", allow_repeats: bool = True
-) -> gtn.Graph:
- """Constructs a graph with all the individual token transition models."""
- if not allow_repeats and blank != "optional":
- raise ValueError("Must use blank='optional' if disallowing repeats.")
-
- ntoks = len(token_list)
- graph = gtn.Graph(False)
-
- # Creating nodes
- graph.add_node(True, True)
- for i in range(ntoks):
- # We can consume one or more consecutive word
- # pieces for each emission:
- # E.g. [ab, ab, ab] transduces to [ab]
- graph.add_node(False, blank != "forced")
-
- if blank != "none":
- graph.add_node()
-
- # Creating arcs
- if blank != "none":
- # Blank index is assumed to be last (ntoks)
- graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
- graph.add_arc(ntoks + 1, 0, gtn.epsilon)
-
- for i in range(ntoks):
- graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
- graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
-
- if allow_repeats:
- if blank == "forced":
- # Allow transitions from token to blank only
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- else:
- # Allow transition from token to blank and all other tokens
- graph.add_arc(i + 1, 0, gtn.epsilon)
-
- else:
- # allow transitions to blank and all other tokens except the same token
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- for j in range(ntoks):
- if i != j:
- graph.add_arc(i + 1, j + 1, j, j)
-
- return graph
-
-
-class TransducerLossFunction(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- inputs,
- targets,
- tokens,
- lexicon,
- transition_params=None,
- transitions=None,
- reduction="none",
- ) -> Tensor:
- B, T, C = inputs.shape
-
- losses = [None] * B
- emissions_graphs = [None] * B
-
- if transitions is not None:
- if transition_params is None:
- raise ValueError("Specified transitions, but not transition params.")
-
- cpu_data = transition_params.cpu().contiguous()
- transitions.set_weights(cpu_data.data_ptr())
- transitions.calc_grad = transition_params.requires_grad
- transitions.zero_grad()
-
- def process(b: int) -> None:
- # Create emission graph:
- emissions = gtn.linear_graph(T, C, inputs.requires_grad)
- cpu_data = inputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
- target = make_chain_graph(targets[b])
- target.arc_sort(True)
-
- # Create token to grapheme decomposition graph
- tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
- tokens_target.arc_sort()
-
- # Create alignment graph:
- aligments = gtn.project_input(
- gtn.remove(gtn.compose(tokens, tokens_target))
- )
- aligments.arc_sort()
-
- # Add transitions scores:
- if transitions is not None:
- aligments = gtn.intersect(transitions, aligments)
- aligments.arc_sort()
-
- loss = gtn.forward_score(gtn.intersect(emissions, aligments))
-
- # Normalize if needed:
- if transitions is not None:
- norm = gtn.forward_score(gtn.intersect(emissions, transitions))
- loss = gtn.subtract(loss, norm)
-
- losses[b] = gtn.negate(loss)
-
- # Save for backward:
- if emissions.calc_grad:
- emissions_graphs[b] = emissions
-
- gtn.parallel_for(process, range(B))
-
- ctx.graphs = (losses, emissions_graphs, transitions)
- ctx.input_shape = inputs.shape
-
- # Optionally reduce by target length
- if reduction == "mean":
- scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
- else:
- scales = [1.0] * B
-
- ctx.scales = scales
-
- loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
- return torch.mean(loss.to(inputs.device))
-
- @staticmethod
- def backward(ctx, grad_output) -> Tuple:
- losses, emissions_graphs, transitions = ctx.graphs
- scales = ctx.scales
-
- B, T, C = ctx.input_shape
- calc_emissions = ctx.needs_input_grad[0]
- input_grad = torch.empty((B, T, C)) if calc_emissions else None
-
- def process(b: int) -> None:
- scale = make_scalar_graph(scales[b])
- gtn.backward(losses[b], scale)
- emissions = emissions_graphs[b]
- if calc_emissions:
- grad = emissions.grad().weights_to_numpy()
- input_grad[b] = torch.tensor(grad).view(1, T, C)
-
- gtn.parallel_for(process, range(B))
-
- if calc_emissions:
- input_grad = input_grad.to(grad_output.device)
- input_grad *= grad_output / B
-
- if ctx.needs_input_grad[4]:
- grad = transitions.grad().weights_to_numpy()
- transition_grad = torch.tensor(grad).to(grad_output.device)
- transition_grad *= grad_output / B
- else:
- transition_grad = None
-
- return (
- input_grad,
- None, # target
- None, # tokens
- None, # lexicon
- transition_grad, # transition params
- None, # transitions graph
- None,
- )
-
-
-TransducerLoss = TransducerLossFunction.apply
-
-
-class Transducer(nn.Module):
- def __init__(
- self,
- preprocessor: Preprocessor,
- ngram: int = 0,
- transitions: str = None,
- blank: str = "none",
- allow_repeats: bool = True,
- reduction: str = "none",
- ) -> None:
- """A generic transducer loss function.
-
- Args:
- preprocessor (Preprocessor) : The IAM preprocessor for word pieces.
- ngram (int) : Order of the token-level transition model. If `ngram=0`
- then no transition model is used.
- blank (string) : Specifies the usage of blank token
- 'none' - do not use blank token
- 'optional' - allow an optional blank in between tokens
- 'forced' - force a blank in between tokens (also referred to as garbage token)
- allow_repeats (boolean) : If false, then we don't allow paths with
- consecutive tokens in the alignment graph. This keeps the graph
- unambiguous in the sense that the same input cannot transduce to
- different outputs.
- """
- super().__init__()
- if blank not in ["optional", "forced", "none"]:
- raise ValueError(
- "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
- )
- self.tokens = make_token_graph(
- preprocessor.tokens, blank=blank, allow_repeats=allow_repeats
- )
- self.lexicon = make_lexicon_graph(
- preprocessor.tokens,
- preprocessor.graphemes_to_index,
- preprocessor.special_tokens,
- )
- self.ngram = ngram
-
- self.transitions: Optional[gtn.Graph] = None
- self.transitions_params: Optional[nn.Parameter] = None
- self._load_transitions(transitions, preprocessor, blank)
-
- if ngram > 0 and transitions is not None:
- raise ValueError("Only one of ngram and transitions may be specified")
-
- self.reduction = reduction
-
- def _load_transitions(
- self, transitions: Optional[str], preprocessor: Preprocessor, blank: str
- ):
- """Loads transition graph."""
- processed_path = (
- Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
- )
- if transitions is not None:
- transitions = gtn.load(str(processed_path / transitions))
- if self.ngram > 0:
- self.transitions = make_transitions_graph(
- self.ngram, len(preprocessor.tokens) + int(blank != "none"), True
- )
- if transitions is not None:
- self.transitions = transitions
- self.transitions.arc_sort()
- self.transitions_params = nn.Parameter(
- torch.zeros(self.transitions.num_arcs())
- )
-
- def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
- return TransducerLoss(
- inputs,
- targets,
- self.tokens,
- self.lexicon,
- self.transitions_params,
- self.transitions,
- self.reduction,
- )
-
- def viterbi(self, outputs: Tensor) -> List[Tensor]:
- B, T, C = outputs.shape
-
- if self.transitions is not None:
- cpu_data = self.transition_params.cpu().contiguous()
- self.transitions.set_weights(cpu_data.data_ptr())
- self.transitions.calc_grad = False
-
- self.tokens.arc_sort()
-
- paths = [None] * B
-
- def process(b: int) -> None:
- emissions = gtn.linear_graph(T, C, False)
- cpu_data = outputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
-
- if self.transitions is not None:
- full_graph = gtn.intersect(emissions, self.transitions)
- else:
- full_graph = emissions
-
- # Find the best path and remove back-off arcs:
- path = gtn.remove(gtn.viterbi_path(full_graph))
-
- # Left compose the viterbi path with the "alignment to token"
- # transducer to get the outputs:
- path = gtn.compose(path, self.tokens)
-
- # When there are ambiguous paths (allow_repeats is true), we take
- # the shortest:
- path = gtn.viterbi_path(path)
- path = gtn.remove(gtn.project_output(path))
- paths[b] = path.labels_to_list()
-
- gtn.parallel_for(process, range(B))
- predictions = [torch.IntTensor(path) for path in paths]
- return predictions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py
deleted file mode 100644
index 9d1cddd..0000000
--- a/text_recognizer/criterions/vqgan_loss.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""VQGAN loss for PyTorch Lightning."""
-from typing import Optional, Tuple
-
-import torch
-from torch import nn, Tensor
-import torch.nn.functional as F
-
-from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator
-
-
-def _adopt_weight(
- weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0
-) -> float:
- """Sets weight to value after the threshold is passed."""
- if global_step < threshold:
- weight = value
- return weight
-
-
-class VQGANLoss(nn.Module):
- """VQGAN loss."""
-
- def __init__(
- self,
- reconstruction_loss: nn.L1Loss,
- discriminator: NLayerDiscriminator,
- commitment_weight: float = 1.0,
- discriminator_weight: float = 1.0,
- discriminator_factor: float = 1.0,
- discriminator_iter_start: int = 1000,
- ) -> None:
- super().__init__()
- self.reconstruction_loss = reconstruction_loss
- self.discriminator = discriminator
- self.commitment_weight = commitment_weight
- self.discriminator_weight = discriminator_weight
- self.discriminator_factor = discriminator_factor
- self.discriminator_iter_start = discriminator_iter_start
-
- @staticmethod
- def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor:
- """Calculates the adversarial loss."""
- loss_real = torch.mean(F.relu(1.0 - logits_real))
- loss_fake = torch.mean(F.relu(1.0 + logits_fake))
- d_loss = (loss_real + loss_fake) / 2.0
- return d_loss
-
- def _adaptive_weight(
- self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor
- ) -> Tensor:
- rec_grads = torch.autograd.grad(
- rec_loss, decoder_last_layer, retain_graph=True
- )[0]
- g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0]
- d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4)
- d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach()
- d_weight *= self.discriminator_weight
- return d_weight
-
- def forward(
- self,
- data: Tensor,
- reconstructions: Tensor,
- commitment_loss: Tensor,
- decoder_last_layer: Tensor,
- optimizer_idx: int,
- global_step: int,
- stage: str,
- ) -> Optional[Tuple]:
- """Calculates the VQGAN loss."""
- rec_loss: Tensor = self.reconstruction_loss(reconstructions, data)
-
- # GAN part.
- if optimizer_idx == 0:
- logits_fake = self.discriminator(reconstructions)
- g_loss = -torch.mean(logits_fake)
-
- if self.training:
- d_weight = self._adaptive_weight(
- rec_loss=rec_loss,
- g_loss=g_loss,
- decoder_last_layer=decoder_last_layer,
- )
- else:
- d_weight = torch.tensor(0.0)
-
- d_factor = _adopt_weight(
- self.discriminator_factor,
- global_step=global_step,
- threshold=self.discriminator_iter_start,
- )
-
- loss: Tensor = (
- rec_loss
- + d_factor * d_weight * g_loss
- + self.commitment_weight * commitment_loss
- )
- log = {
- f"{stage}/total_loss": loss,
- f"{stage}/commitment_loss": commitment_loss,
- f"{stage}/rec_loss": rec_loss,
- f"{stage}/g_loss": g_loss,
- }
- return loss, log
-
- if optimizer_idx == 1:
- logits_fake = self.discriminator(reconstructions.detach())
- logits_real = self.discriminator(data.detach())
-
- d_factor = _adopt_weight(
- self.discriminator_factor,
- global_step=global_step,
- threshold=self.discriminator_iter_start,
- )
-
- d_loss = d_factor * self.adversarial_loss(
- logits_real=logits_real, logits_fake=logits_fake
- )
-
- log = {
- f"{stage}/d_loss": d_loss,
- }
- return d_loss, log