diff options
Diffstat (limited to 'text_recognizer/networks/transducer')
-rw-r--r-- | text_recognizer/networks/transducer/__init__.py | 3 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/tds_conv.py | 208 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/test.py | 60 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/transducer.py | 410 |
4 files changed, 681 insertions, 0 deletions
diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py new file mode 100644 index 0000000..8c19a01 --- /dev/null +++ b/text_recognizer/networks/transducer/__init__.py @@ -0,0 +1,3 @@ +"""Transducer modules.""" +from .tds_conv import TDS2d +from .transducer import load_transducer_loss, Transducer diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py new file mode 100644 index 0000000..5fb8ba9 --- /dev/null +++ b/text_recognizer/networks/transducer/tds_conv.py @@ -0,0 +1,208 @@ +"""Time-Depth Separable Convolutions. + +References: + https://arxiv.org/abs/1904.02619 + https://arxiv.org/pdf/2010.01003.pdf + +Code stolen from: + https://github.com/facebookresearch/gtn_applications + + +""" +from typing import List, Tuple + +from einops import rearrange +import gtn +import numpy as np +import torch +from torch import nn +from torch import Tensor + + +class TDSBlock2d(nn.Module): + """Internal block of a 2D TDSC network.""" + + def __init__( + self, + in_channels: int, + img_depth: int, + kernel_size: Tuple[int], + dropout_rate: float, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.img_depth = img_depth + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + self.fc_dim = in_channels * img_depth + + # Network placeholders. + self.conv = None + self.mlp = None + self.instance_norm = None + + self._build_block() + + def _build_block(self) -> None: + # Convolutional block. + self.conv = nn.Sequential( + nn.Conv3d( + in_channels=self.in_channels, + out_channels=self.in_channels, + kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), + padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), + ), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + ) + + # MLP block. 
+ self.mlp = nn.Sequential( + nn.Linear(self.fc_dim, self.fc_dim), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.Linear(self.fc_dim, self.fc_dim), + nn.Dropout(self.dropout_rate), + ) + + # Instance norm. + self.instance_norm = nn.ModuleList( + [ + nn.InstanceNorm2d(self.fc_dim, affine=True), + nn.InstanceNorm2d(self.fc_dim, affine=True), + ] + ) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Args: + x (Tensor): Input tensor. + + Shape: + - x: :math: `(B, CD, H, W)` + + Returns: + Tensor: Output tensor. + + """ + B, CD, H, W = x.shape + C, D = self.in_channels, self.img_depth + residual = x + x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) + x = self.conv(x) + x = rearrange(x, "b c d h w -> b (c d) h w") + x += residual + + x = self.instance_norm[0](x) + + x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x + x + self.instance_norm[1](x) + + # Output shape: [B, CD, H, W] + return x + + +class TDS2d(nn.Module): + """TDS Netowrk. + + Structure is the following: + Downsample layer -> TDS2d group -> ... -> Linear output layer + + + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + depth: int, + tds_groups: Tuple[int], + kernel_size: Tuple[int], + dropout_rate: float, + in_channels: int = 1, + ) -> None: + super().__init__() + + self.in_channels = in_channels + self.input_dim = input_dim + self.output_dim = output_dim + self.depth = depth + self.tds_groups = tds_groups + self.kernel_size = kernel_size + self.dropout_rate = dropout_rate + + self.tds = None + self.fc = None + + self._build_network() + + def _build_network(self) -> None: + in_channels = self.in_channels + modules = [] + stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) + if self.input_dim % stride_h: + raise RuntimeError( + f"Image height not divisible by total stride {stride_h}." + ) + + for tds_group in self.tds_groups: + # Add downsample layer. 
+ out_channels = self.depth * tds_group["channels"] + modules.extend( + [ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=self.kernel_size, + padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), + stride=tds_group["stride"], + ), + nn.ReLU(inplace=True), + nn.Dropout(self.dropout_rate), + nn.InstanceNorm2d(out_channels, affine=True), + ] + ) + + for _ in range(tds_group["num_blocks"]): + modules.append( + TDSBlock2d( + tds_group["channels"], + self.depth, + self.kernel_size, + self.dropout_rate, + ) + ) + + in_channels = out_channels + + self.tds = nn.Sequential(*modules) + self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass. + + Args: + x (Tensor): Input tensor. + + Shape: + - x: :math: `(B, H, W)` + + Returns: + Tensor: Output tensor. + + """ + if len(x.shape) == 4: + x = x.squeeze(1) # Squeeze the channel dim away. + + B, H, W = x.shape + x = rearrange( + x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels + ) + x = self.tds(x) + + # x shape: [B, C, H, W] + x = rearrange(x, "b c h w -> b w (c h)") + + return self.fc(x) diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py new file mode 100644 index 0000000..cadcecc --- /dev/null +++ b/text_recognizer/networks/transducer/test.py @@ -0,0 +1,60 @@ +import torch +from torch import nn + +from text_recognizer.networks.transducer import load_transducer_loss, Transducer +import unittest + + +class TestTransducer(unittest.TestCase): + def test_viterbi(self): + T = 5 + N = 4 + B = 2 + + # fmt: off + emissions1 = torch.tensor(( + 0, 4, 0, 1, + 0, 2, 1, 1, + 0, 0, 0, 2, + 0, 0, 0, 2, + 8, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + emissions2 = torch.tensor(( + 0, 2, 1, 7, + 0, 2, 9, 1, + 0, 0, 0, 2, + 0, 0, 5, 2, + 1, 0, 0, 2, + ), + dtype=torch.float, + ).view(T, N) + # fmt: on + + # Test without blank: + 
labels = [[1, 3, 0], [3, 2, 3, 2, 3]] + transducer = Transducer( + tokens=["a", "b", "c", "d"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, + blank="none", + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + # Test with blank without repeats: + labels = [[1, 0], [2, 2]] + transducer = Transducer( + tokens=["a", "b", "c"], + graphemes_to_idx={"a": 0, "b": 1, "c": 2}, + blank="optional", + allow_repeats=False, + ) + emissions = torch.stack([emissions1, emissions2], dim=0) + predictions = transducer.viterbi(emissions) + self.assertEqual([p.tolist() for p in predictions], labels) + + +if __name__ == "__main__": + unittest.main() diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py new file mode 100644 index 0000000..d7e3d08 --- /dev/null +++ b/text_recognizer/networks/transducer/transducer.py @@ -0,0 +1,410 @@ +"""Transducer and the transducer loss function.py + +Stolen from: + https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py + +""" +from pathlib import Path +import itertools +from typing import Dict, List, Optional, Union, Tuple + +from loguru import logger +import gtn +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.datasets.iam_preprocessor import Preprocessor + + +def make_scalar_graph(weight) -> gtn.Graph: + scalar = gtn.Graph() + scalar.add_node(True) + scalar.add_node(False, True) + scalar.add_arc(0, 1, 0, 0, weight) + return scalar + + +def make_chain_graph(sequence) -> gtn.Graph: + graph = gtn.Graph(False) + graph.add_node(True) + for i, s in enumerate(sequence): + graph.add_node(False, i == (len(sequence) - 1)) + graph.add_arc(i, i + 1, s) + return graph + + +def make_transitions_graph( + ngram: int, num_tokens: int, calc_grad: bool = False +) -> gtn.Graph: + transitions = gtn.Graph(calc_grad) + 
transitions.add_node(True, ngram == 1) + + state_map = {(): 0} + + # First build transitions which include <s>: + for n in range(1, ngram): + for state in itertools.product(range(num_tokens), repeat=n): + in_idx = state_map[state[:-1]] + out_idx = transitions.add_node(False, ngram == 1) + state_map[state] = out_idx + transitions.add_arc(in_idx, out_idx, state[-1]) + + for state in itertools.product(range(num_tokens), repeat=ngram): + state_idx = state_map[state[:-1]] + new_state_idx = state_map[state[1:]] + # p(state[-1] | state[:-1]) + transitions.add_arc(state_idx, new_state_idx, state[-1]) + + if ngram > 1: + # Build transitions which include </s>: + end_idx = transitions.add_node(False, True) + for in_idx in range(end_idx): + transitions.add_arc(in_idx, end_idx, gtn.epsilon) + + return transitions + + +def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: + """Constructs a graph which transduces letters to word pieces.""" + graph = gtn.Graph(False) + graph.add_node(True, True) + for i, wp in enumerate(word_pieces): + prev = 0 + for l in wp[:-1]: + n = graph.add_node() + graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) + prev = n + graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) + graph.arc_sort() + return graph + + +def make_token_graph( + token_list: List, blank: str = "none", allow_repeats: bool = True +) -> gtn.Graph: + """Constructs a graph with all the individual token transition models.""" + if not allow_repeats and blank != "optional": + raise ValueError("Must use blank='optional' if disallowing repeats.") + + ntoks = len(token_list) + graph = gtn.Graph(False) + + # Creating nodes + graph.add_node(True, True) + for i in range(ntoks): + # We can consume one or more consecutive word + # pieces for each emission: + # E.g. 
[ab, ab, ab] transduces to [ab] + graph.add_node(False, blank != "forced") + + if blank != "none": + graph.add_node() + + # Creating arcs + if blank != "none": + # Blank index is assumed to be last (ntoks) + graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) + graph.add_arc(ntoks + 1, 0, gtn.epsilon) + + for i in range(ntoks): + graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) + graph.add_arc(i + 1, i + 1, i, gtn.epsilon) + + if allow_repeats: + if blank == "forced": + # Allow transitions from token to blank only + graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) + else: + # Allow transition from token to blank and all other tokens + graph.add_arc(i + 1, 0, gtn.epsilon) + + else: + # allow transitions to blank and all other tokens except the same token + graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) + for j in range(ntoks): + if i != j: + graph.add_arc(i + 1, j + 1, j, j) + + return graph + + +class TransducerLossFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + inputs, + targets, + tokens, + lexicon, + transition_params=None, + transitions=None, + reduction="none", + ) -> Tensor: + B, T, C = inputs.shape + + losses = [None] * B + emissions_graphs = [None] * B + + if transitions is not None: + if transition_params is None: + raise ValueError("Specified transitions, but not transition params.") + + cpu_data = transition_params.cpu().contiguous() + transitions.set_weights(cpu_data.data_ptr()) + transitions.calc_grad = transition_params.requires_grad + transitions.zero_grad() + + def process(b: int) -> None: + # Create emission graph: + emissions = gtn.linear_graph(T, C, inputs.requires_grad) + cpu_data = inputs[b].cpu().contiguous() + emissions.set_weights(cpu_data.data_ptr()) + target = make_chain_graph(targets[b]) + target.arc_sort(True) + + # Create token to grapheme decomposition graph + tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) + tokens_target.arc_sort() + + # Create alignment 
graph: + aligments = gtn.project_input( + gtn.remove(gtn.compose(tokens, tokens_target)) + ) + aligments.arc_sort() + + # Add transitions scores: + if transitions is not None: + aligments = gtn.intersect(transitions, aligments) + aligments.arc_sort() + + loss = gtn.forward_score(gtn.intersect(emissions, aligments)) + + # Normalize if needed: + if transitions is not None: + norm = gtn.forward_score(gtn.intersect(emissions, transitions)) + loss = gtn.subtract(loss, norm) + + losses[b] = gtn.negate(loss) + + # Save for backward: + if emissions.calc_grad: + emissions_graphs[b] = emissions + + gtn.parallel_for(process, range(B)) + + ctx.graphs = (losses, emissions_graphs, transitions) + ctx.input_shape = inputs.shape + + # Optionally reduce by target length + if reduction == "mean": + scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] + else: + scales = [1.0] * B + + ctx.scales = scales + + loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) + return torch.mean(loss.to(inputs.device)) + + @staticmethod + def backward(ctx, grad_output) -> Tuple: + losses, emissions_graphs, transitions = ctx.graphs + scales = ctx.scales + + B, T, C = ctx.input_shape + calc_emissions = ctx.needs_input_grad[0] + input_grad = torch.empty((B, T, C)) if calc_emissions else None + + def process(b: int) -> None: + scale = make_scalar_graph(scales[b]) + gtn.backward(losses[b], scale) + emissions = emissions_graphs[b] + if calc_emissions: + grad = emissions.grad().weights_to_numpy() + input_grad[b] = torch.tensor(grad).view(1, T, C) + + gtn.parallel_for(process, range(B)) + + if calc_emissions: + input_grad = input_grad.to(grad_output.device) + input_grad *= grad_output / B + + if ctx.needs_input_grad[4]: + grad = transitions.grad().weights_to_numpy() + transition_grad = torch.tensor(grad).to(grad_output.device) + transition_grad *= grad_output / B + else: + transition_grad = None + + return ( + input_grad, + None, # target + None, # tokens + None, # lexicon + 
transition_grad, # transition params + None, # transitions graph + None, + ) + + +TransducerLoss = TransducerLossFunction.apply + + +class Transducer(nn.Module): + def __init__( + self, + tokens: List, + graphemes_to_idx: Dict, + ngram: int = 0, + transitions: str = None, + blank: str = "none", + allow_repeats: bool = True, + reduction: str = "none", + ) -> None: + """A generic transducer loss function. + + Args: + tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) + representing the output tokens of the model (e.g. letters, + word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] + could be a list of sub-word tokens. + graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. + "a", "b", ..) to their corresponding integer index. + ngram (int) : Order of the token-level transition model. If `ngram=0` + then no transition model is used. + blank (string) : Specifies the usage of blank token + 'none' - do not use blank token + 'optional' - allow an optional blank inbetween tokens + 'forced' - force a blank inbetween tokens (also referred to as garbage token) + allow_repeats (boolean) : If false, then we don't allow paths with + consecutive tokens in the alignment graph. This keeps the graph + unambiguous in the sense that the same input cannot transduce to + different outputs. + """ + super().__init__() + if blank not in ["optional", "forced", "none"]: + raise ValueError( + "Invalid value specified for blank. 
Must be in ['optional', 'forced', 'none']" + ) + self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) + self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) + self.ngram = ngram + if ngram > 0 and transitions is not None: + raise ValueError("Only one of ngram and transitions may be specified") + + if ngram > 0: + transitions = make_transitions_graph( + ngram, len(tokens) + int(blank != "none"), True + ) + + if transitions is not None: + self.transitions = transitions + self.transitions.arc_sort() + self.transitions_params = nn.Parameter( + torch.zeros(self.transitions.num_arcs()) + ) + else: + self.transitions = None + self.transitions_params = None + self.reduction = reduction + + def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: + TransducerLoss( + inputs, + targets, + self.tokens, + self.lexicon, + self.transitions_params, + self.transitions, + self.reduction, + ) + + def viterbi(self, outputs: Tensor) -> List[Tensor]: + B, T, C = outputs.shape + + if self.transitions is not None: + cpu_data = self.transition_params.cpu().contiguous() + self.transitions.set_weights(cpu_data.data_ptr()) + self.transitions.calc_grad = False + + self.tokens.arc_sort() + + paths = [None] * B + + def process(b: int) -> None: + emissions = gtn.linear_graph(T, C, False) + cpu_data = outputs[b].cpu().contiguous() + emissions.set_weights(cpu_data.data_ptr()) + + if self.transitions is not None: + full_graph = gtn.intersect(emissions, self.transitions) + else: + full_graph = emissions + + # Find the best path and remove back-off arcs: + path = gtn.remove(gtn.viterbi_path(full_graph)) + + # Left compose the viterbi path with the "aligment to token" + # transducer to get the outputs: + path = gtn.compose(path, self.tokens) + + # When there are ambiguous paths (allow_repeats is true), we take + # the shortest: + path = gtn.viterbi_path(path) + path = gtn.remove(gtn.project_output(path)) + paths[b] = path.labels_to_list() + + 
gtn.parallel_for(process, range(B)) + predictions = [torch.IntTensor(path) for path in paths] + return predictions + + +def load_transducer_loss( + num_features: int, + ngram: int, + tokens: str, + lexicon: str, + transitions: str, + blank: str, + allow_repeats: bool, + prepend_wordsep: bool = False, + use_words: bool = False, + data_dir: Optional[Union[str, Path]] = None, + reduction: str = "mean", +) -> Tuple[Transducer, int]: + if data_dir is None: + data_dir = ( + Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" + ) + logger.debug(f"Using data dir: {data_dir}") + if not data_dir.exists(): + raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") + else: + data_dir = Path(data_dir) + processed_path = ( + Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" + ) + tokens_path = processed_path / tokens + lexicon_path = processed_path / lexicon + + if transitions is not None: + transitions = gtn.load(str(processed_path / transitions)) + + preprocessor = Preprocessor( + data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, + ) + + num_tokens = preprocessor.num_tokens + + criterion = Transducer( + preprocessor.tokens, + preprocessor.graphemes_to_index, + ngram=ngram, + transitions=transitions, + blank=blank, + allow_repeats=allow_repeats, + reduction=reduction, + ) + + return criterion, num_tokens + int(blank != "none") |