From 01d6e5fc066969283df99c759609df441151e9c5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 6 Jun 2021 23:19:35 +0200 Subject: Working on fixing decoder transformer --- text_recognizer/networks/transducer/__init__.py | 3 - text_recognizer/networks/transducer/tds_conv.py | 208 ----------- text_recognizer/networks/transducer/test.py | 60 ---- text_recognizer/networks/transducer/transducer.py | 410 ---------------------- 4 files changed, 681 deletions(-) delete mode 100644 text_recognizer/networks/transducer/__init__.py delete mode 100644 text_recognizer/networks/transducer/tds_conv.py delete mode 100644 text_recognizer/networks/transducer/test.py delete mode 100644 text_recognizer/networks/transducer/transducer.py (limited to 'text_recognizer/networks/transducer') diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py deleted file mode 100644 index 8c19a01..0000000 --- a/text_recognizer/networks/transducer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transducer modules.""" -from .tds_conv import TDS2d -from .transducer import load_transducer_loss, Transducer diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py deleted file mode 100644 index 5fb8ba9..0000000 --- a/text_recognizer/networks/transducer/tds_conv.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Time-Depth Separable Convolutions. - -References: - https://arxiv.org/abs/1904.02619 - https://arxiv.org/pdf/2010.01003.pdf - -Code stolen from: - https://github.com/facebookresearch/gtn_applications - - -""" -from typing import List, Tuple - -from einops import rearrange -import gtn -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class TDSBlock2d(nn.Module): - """Internal block of a 2D TDSC network.""" - - def __init__( - self, - in_channels: int, - img_depth: int, - kernel_size: Tuple[int], - dropout_rate: float, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.img_depth = img_depth - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - self.fc_dim = in_channels * img_depth - - # Network placeholders. - self.conv = None - self.mlp = None - self.instance_norm = None - - self._build_block() - - def _build_block(self) -> None: - # Convolutional block. - self.conv = nn.Sequential( - nn.Conv3d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), - padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - ) - - # MLP block. - self.mlp = nn.Sequential( - nn.Linear(self.fc_dim, self.fc_dim), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.Linear(self.fc_dim, self.fc_dim), - nn.Dropout(self.dropout_rate), - ) - - # Instance norm. - self.instance_norm = nn.ModuleList( - [ - nn.InstanceNorm2d(self.fc_dim, affine=True), - nn.InstanceNorm2d(self.fc_dim, affine=True), - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, CD, H, W)` - - Returns: - Tensor: Output tensor. - - """ - B, CD, H, W = x.shape - C, D = self.in_channels, self.img_depth - residual = x - x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) - x = self.conv(x) - x = rearrange(x, "b c d h w -> b (c d) h w") - x += residual - - x = self.instance_norm[0](x) - - x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x - x + self.instance_norm[1](x) - - # Output shape: [B, CD, H, W] - return x - - -class TDS2d(nn.Module): - """TDS Netowrk. - - Structure is the following: - Downsample layer -> TDS2d group -> ... -> Linear output layer - - - """ - - def __init__( - self, - input_dim: int, - output_dim: int, - depth: int, - tds_groups: Tuple[int], - kernel_size: Tuple[int], - dropout_rate: float, - in_channels: int = 1, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.input_dim = input_dim - self.output_dim = output_dim - self.depth = depth - self.tds_groups = tds_groups - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - - self.tds = None - self.fc = None - - self._build_network() - - def _build_network(self) -> None: - in_channels = self.in_channels - modules = [] - stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) - if self.input_dim % stride_h: - raise RuntimeError( - f"Image height not divisible by total stride {stride_h}." - ) - - for tds_group in self.tds_groups: - # Add downsample layer. - out_channels = self.depth * tds_group["channels"] - modules.extend( - [ - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=self.kernel_size, - padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), - stride=tds_group["stride"], - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.InstanceNorm2d(out_channels, affine=True), - ] - ) - - for _ in range(tds_group["num_blocks"]): - modules.append( - TDSBlock2d( - tds_group["channels"], - self.depth, - self.kernel_size, - self.dropout_rate, - ) - ) - - in_channels = out_channels - - self.tds = nn.Sequential(*modules) - self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, H, W)` - - Returns: - Tensor: Output tensor. - - """ - if len(x.shape) == 4: - x = x.squeeze(1) # Squeeze the channel dim away. - - B, H, W = x.shape - x = rearrange( - x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels - ) - x = self.tds(x) - - # x shape: [B, C, H, W] - x = rearrange(x, "b c h w -> b w (c h)") - - return self.fc(x) diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py deleted file mode 100644 index cadcecc..0000000 --- a/text_recognizer/networks/transducer/test.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import nn - -from text_recognizer.networks.transducer import load_transducer_loss, Transducer -import unittest - - -class TestTransducer(unittest.TestCase): - def test_viterbi(self): - T = 5 - N = 4 - B = 2 - - # fmt: off - emissions1 = torch.tensor(( - 0, 4, 0, 1, - 0, 2, 1, 1, - 0, 0, 0, 2, - 0, 0, 0, 2, - 8, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - emissions2 = torch.tensor(( - 0, 2, 1, 7, - 0, 2, 9, 1, - 0, 0, 0, 2, - 0, 0, 5, 2, - 1, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - # fmt: on - - # Test without blank: - labels = [[1, 3, 0], [3, 2, 3, 2, 3]] - transducer = Transducer( - tokens=["a", "b", "c", "d"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, - blank="none", - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - # Test with blank without repeats: - labels = [[1, 0], [2, 2]] - transducer = Transducer( - tokens=["a", "b", "c"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2}, - blank="optional", - allow_repeats=False, - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - -if __name__ == "__main__": - unittest.main() diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py deleted file mode 100644 index d7e3d08..0000000 --- a/text_recognizer/networks/transducer/transducer.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Transducer and the transducer loss function.py - -Stolen from: - https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py - -""" -from pathlib import Path -import itertools -from typing import Dict, List, Optional, Union, Tuple - -from loguru import logger -import gtn -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.datasets.iam_preprocessor import Preprocessor - - -def make_scalar_graph(weight) -> gtn.Graph: - scalar = gtn.Graph() - scalar.add_node(True) - scalar.add_node(False, True) - scalar.add_arc(0, 1, 0, 0, weight) - return scalar - - -def make_chain_graph(sequence) -> gtn.Graph: - graph = gtn.Graph(False) - graph.add_node(True) - for i, s in enumerate(sequence): - graph.add_node(False, i == (len(sequence) - 1)) - graph.add_arc(i, i + 1, s) - return graph - - -def make_transitions_graph( - ngram: int, num_tokens: int, calc_grad: bool = False -) -> gtn.Graph: - transitions = gtn.Graph(calc_grad) - transitions.add_node(True, ngram == 1) - - state_map = {(): 0} - - # First build transitions which include : - for n in range(1, ngram): - for state in itertools.product(range(num_tokens), repeat=n): - in_idx = state_map[state[:-1]] - out_idx = transitions.add_node(False, ngram == 1) - state_map[state] = out_idx - transitions.add_arc(in_idx, out_idx, state[-1]) - - for state in itertools.product(range(num_tokens), repeat=ngram): - state_idx = state_map[state[:-1]] - new_state_idx = state_map[state[1:]] - # p(state[-1] | state[:-1]) - transitions.add_arc(state_idx, new_state_idx, state[-1]) - - if ngram > 1: - # Build transitions which include : - end_idx = transitions.add_node(False, True) - for in_idx in range(end_idx): - transitions.add_arc(in_idx, end_idx, gtn.epsilon) - - return transitions - - -def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: - """Constructs a graph which transduces letters to word pieces.""" - graph = gtn.Graph(False) - graph.add_node(True, True) - for i, wp in enumerate(word_pieces): - prev = 0 - for l in wp[:-1]: - n = graph.add_node() - graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) - prev = n - graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) - graph.arc_sort() - return graph - - -def make_token_graph( - token_list: List, blank: str = "none", allow_repeats: bool = True -) -> gtn.Graph: - """Constructs a graph with all the individual token transition models.""" - if not allow_repeats and blank != "optional": - raise ValueError("Must use blank='optional' if disallowing repeats.") - - ntoks = len(token_list) - graph = gtn.Graph(False) - - # Creating nodes - graph.add_node(True, True) - for i in range(ntoks): - # We can consume one or more consecutive word - # pieces for each emission: - # E.g. [ab, ab, ab] transduces to [ab] - graph.add_node(False, blank != "forced") - - if blank != "none": - graph.add_node() - - # Creating arcs - if blank != "none": - # Blank index is assumed to be last (ntoks) - graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) - graph.add_arc(ntoks + 1, 0, gtn.epsilon) - - for i in range(ntoks): - graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) - graph.add_arc(i + 1, i + 1, i, gtn.epsilon) - - if allow_repeats: - if blank == "forced": - # Allow transitions from token to blank only - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - else: - # Allow transition from token to blank and all other tokens - graph.add_arc(i + 1, 0, gtn.epsilon) - - else: - # allow transitions to blank and all other tokens except the same token - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - for j in range(ntoks): - if i != j: - graph.add_arc(i + 1, j + 1, j, j) - - return graph - - -class TransducerLossFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - inputs, - targets, - tokens, - lexicon, - transition_params=None, - transitions=None, - reduction="none", - ) -> Tensor: - B, T, C = inputs.shape - - losses = [None] * B - emissions_graphs = [None] * B - - if transitions is not None: - if transition_params is None: - raise ValueError("Specified transitions, but not transition params.") - - cpu_data = transition_params.cpu().contiguous() - transitions.set_weights(cpu_data.data_ptr()) - transitions.calc_grad = transition_params.requires_grad - transitions.zero_grad() - - def process(b: int) -> None: - # Create emission graph: - emissions = gtn.linear_graph(T, C, inputs.requires_grad) - cpu_data = inputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - target = make_chain_graph(targets[b]) - target.arc_sort(True) - - # Create token tot grapheme decomposition graph - tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) - tokens_target.arc_sort() - - # Create alignment graph: - aligments = gtn.project_input( - gtn.remove(gtn.compose(tokens, tokens_target)) - ) - aligments.arc_sort() - - # Add transitions scores: - if transitions is not None: - aligments = gtn.intersect(transitions, aligments) - aligments.arc_sort() - - loss = gtn.forward_score(gtn.intersect(emissions, aligments)) - - # Normalize if needed: - if transitions is not None: - norm = gtn.forward_score(gtn.intersect(emissions, transitions)) - loss = gtn.subtract(loss, norm) - - losses[b] = gtn.negate(loss) - - # Save for backward: - if emissions.calc_grad: - emissions_graphs[b] = emissions - - gtn.parallel_for(process, range(B)) - - ctx.graphs = (losses, emissions_graphs, transitions) - ctx.input_shape = inputs.shape - - # Optionally reduce by target length - if reduction == "mean": - scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] - else: - scales = [1.0] * B - - ctx.scales = scales - - loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) - return torch.mean(loss.to(inputs.device)) - - @staticmethod - def backward(ctx, grad_output) -> Tuple: - losses, emissions_graphs, transitions = ctx.graphs - scales = ctx.scales - - B, T, C = ctx.input_shape - calc_emissions = ctx.needs_input_grad[0] - input_grad = torch.empty((B, T, C)) if calc_emissions else None - - def process(b: int) -> None: - scale = make_scalar_graph(scales[b]) - gtn.backward(losses[b], scale) - emissions = emissions_graphs[b] - if calc_emissions: - grad = emissions.grad().weights_to_numpy() - input_grad[b] = torch.tensor(grad).view(1, T, C) - - gtn.parallel_for(process, range(B)) - - if calc_emissions: - input_grad = input_grad.to(grad_output.device) - input_grad *= grad_output / B - - if ctx.needs_input_grad[4]: - grad = transitions.grad().weights_to_numpy() - transition_grad = torch.tensor(grad).to(grad_output.device) - transition_grad *= grad_output / B - else: - transition_grad = None - - return ( - input_grad, - None, # target - None, # tokens - None, # lexicon - transition_grad, # transition params - None, # transitions graph - None, - ) - - -TransducerLoss = TransducerLossFunction.apply - - -class Transducer(nn.Module): - def __init__( - self, - tokens: List, - graphemes_to_idx: Dict, - ngram: int = 0, - transitions: str = None, - blank: str = "none", - allow_repeats: bool = True, - reduction: str = "none", - ) -> None: - """A generic transducer loss function. - - Args: - tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) - representing the output tokens of the model (e.g. letters, - word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] - could be a list of sub-word tokens. - graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. - "a", "b", ..) to their corresponding integer index. - ngram (int) : Order of the token-level transition model. If `ngram=0` - then no transition model is used. - blank (string) : Specifies the usage of blank token - 'none' - do not use blank token - 'optional' - allow an optional blank inbetween tokens - 'forced' - force a blank inbetween tokens (also referred to as garbage token) - allow_repeats (boolean) : If false, then we don't allow paths with - consecutive tokens in the alignment graph. This keeps the graph - unambiguous in the sense that the same input cannot transduce to - different outputs. - """ - super().__init__() - if blank not in ["optional", "forced", "none"]: - raise ValueError( - "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']" - ) - self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) - self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) - self.ngram = ngram - if ngram > 0 and transitions is not None: - raise ValueError("Only one of ngram and transitions may be specified") - - if ngram > 0: - transitions = make_transitions_graph( - ngram, len(tokens) + int(blank != "none"), True - ) - - if transitions is not None: - self.transitions = transitions - self.transitions.arc_sort() - self.transitions_params = nn.Parameter( - torch.zeros(self.transitions.num_arcs()) - ) - else: - self.transitions = None - self.transitions_params = None - self.reduction = reduction - - def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: - TransducerLoss( - inputs, - targets, - self.tokens, - self.lexicon, - self.transitions_params, - self.transitions, - self.reduction, - ) - - def viterbi(self, outputs: Tensor) -> List[Tensor]: - B, T, C = outputs.shape - - if self.transitions is not None: - cpu_data = self.transition_params.cpu().contiguous() - self.transitions.set_weights(cpu_data.data_ptr()) - self.transitions.calc_grad = False - - self.tokens.arc_sort() - - paths = [None] * B - - def process(b: int) -> None: - emissions = gtn.linear_graph(T, C, False) - cpu_data = outputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - - if self.transitions is not None: - full_graph = gtn.intersect(emissions, self.transitions) - else: - full_graph = emissions - - # Find the best path and remove back-off arcs: - path = gtn.remove(gtn.viterbi_path(full_graph)) - - # Left compose the viterbi path with the "aligment to token" - # transducer to get the outputs: - path = gtn.compose(path, self.tokens) - - # When there are ambiguous paths (allow_repeats is true), we take - # the shortest: - path = gtn.viterbi_path(path) - path = gtn.remove(gtn.project_output(path)) - paths[b] = path.labels_to_list() - - gtn.parallel_for(process, range(B)) - predictions = [torch.IntTensor(path) for path in paths] - return predictions - - -def load_transducer_loss( - num_features: int, - ngram: int, - tokens: str, - lexicon: str, - transitions: str, - blank: str, - allow_repeats: bool, - prepend_wordsep: bool = False, - use_words: bool = False, - data_dir: Optional[Union[str, Path]] = None, - reduction: str = "mean", -) -> Tuple[Transducer, int]: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - if transitions is not None: - transitions = gtn.load(str(processed_path / transitions)) - - preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, - ) - - num_tokens = preprocessor.num_tokens - - criterion = Transducer( - preprocessor.tokens, - preprocessor.graphemes_to_index, - ngram=ngram, - transitions=transitions, - blank=blank, - allow_repeats=allow_repeats, - reduction=reduction, - ) - - return criterion, num_tokens + int(blank != "none") -- cgit v1.2.3-70-g09d2