diff options
| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-06 23:19:35 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-06-06 23:19:35 +0200 | 
| commit | 01d6e5fc066969283df99c759609df441151e9c5 (patch) | |
| tree | ecd1459e142356d0c7f50a61307b760aca813248 /text_recognizer/networks/transducer | |
| parent | f4688482b4898c0b342d6ae59839dc27fbf856c6 (diff) | |
Working on fixing decoder transformer
Diffstat (limited to 'text_recognizer/networks/transducer')
| -rw-r--r-- | text_recognizer/networks/transducer/__init__.py | 3 | ||||
| -rw-r--r-- | text_recognizer/networks/transducer/tds_conv.py | 208 | ||||
| -rw-r--r-- | text_recognizer/networks/transducer/test.py | 60 | ||||
| -rw-r--r-- | text_recognizer/networks/transducer/transducer.py | 410 | 
4 files changed, 0 insertions, 681 deletions
| diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py deleted file mode 100644 index 8c19a01..0000000 --- a/text_recognizer/networks/transducer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transducer modules.""" -from .tds_conv import TDS2d -from .transducer import load_transducer_loss, Transducer diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py deleted file mode 100644 index 5fb8ba9..0000000 --- a/text_recognizer/networks/transducer/tds_conv.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Time-Depth Separable Convolutions. - -References: -    https://arxiv.org/abs/1904.02619 -    https://arxiv.org/pdf/2010.01003.pdf - -Code stolen from: -    https://github.com/facebookresearch/gtn_applications - - -""" -from typing import List, Tuple - -from einops import rearrange -import gtn -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class TDSBlock2d(nn.Module): -    """Internal block of a 2D TDSC network.""" - -    def __init__( -        self, -        in_channels: int, -        img_depth: int, -        kernel_size: Tuple[int], -        dropout_rate: float, -    ) -> None: -        super().__init__() - -        self.in_channels = in_channels -        self.img_depth = img_depth -        self.kernel_size = kernel_size -        self.dropout_rate = dropout_rate -        self.fc_dim = in_channels * img_depth - -        # Network placeholders. -        self.conv = None -        self.mlp = None -        self.instance_norm = None - -        self._build_block() - -    def _build_block(self) -> None: -        # Convolutional block. -        self.conv = nn.Sequential( -            nn.Conv3d( -                in_channels=self.in_channels, -                out_channels=self.in_channels, -                kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), -                padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), -            ), -            nn.ReLU(inplace=True), -            nn.Dropout(self.dropout_rate), -        ) - -        # MLP block. -        self.mlp = nn.Sequential( -            nn.Linear(self.fc_dim, self.fc_dim), -            nn.ReLU(inplace=True), -            nn.Dropout(self.dropout_rate), -            nn.Linear(self.fc_dim, self.fc_dim), -            nn.Dropout(self.dropout_rate), -        ) - -        # Instance norm. -        self.instance_norm = nn.ModuleList( -            [ -                nn.InstanceNorm2d(self.fc_dim, affine=True), -                nn.InstanceNorm2d(self.fc_dim, affine=True), -            ] -        ) - -    def forward(self, x: Tensor) -> Tensor: -        """Forward pass. - -        Args: -            x (Tensor): Input tensor. - -        Shape: -            - x: :math: `(B, CD, H, W)` - -        Returns: -            Tensor: Output tensor. - -        """ -        B, CD, H, W = x.shape -        C, D = self.in_channels, self.img_depth -        residual = x -        x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) -        x = self.conv(x) -        x = rearrange(x, "b c d h w -> b (c d) h w") -        x += residual - -        x = self.instance_norm[0](x) - -        x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x -        x + self.instance_norm[1](x) - -        # Output shape: [B, CD, H, W] -        return x - - -class TDS2d(nn.Module): -    """TDS Netowrk. - -    Structure is the following: -        Downsample layer -> TDS2d group -> ... -> Linear output layer - - -    """ - -    def __init__( -        self, -        input_dim: int, -        output_dim: int, -        depth: int, -        tds_groups: Tuple[int], -        kernel_size: Tuple[int], -        dropout_rate: float, -        in_channels: int = 1, -    ) -> None: -        super().__init__() - -        self.in_channels = in_channels -        self.input_dim = input_dim -        self.output_dim = output_dim -        self.depth = depth -        self.tds_groups = tds_groups -        self.kernel_size = kernel_size -        self.dropout_rate = dropout_rate - -        self.tds = None -        self.fc = None - -        self._build_network() - -    def _build_network(self) -> None: -        in_channels = self.in_channels -        modules = [] -        stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) -        if self.input_dim % stride_h: -            raise RuntimeError( -                f"Image height not divisible by total stride {stride_h}." -            ) - -        for tds_group in self.tds_groups: -            # Add downsample layer. -            out_channels = self.depth * tds_group["channels"] -            modules.extend( -                [ -                    nn.Conv2d( -                        in_channels=in_channels, -                        out_channels=out_channels, -                        kernel_size=self.kernel_size, -                        padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), -                        stride=tds_group["stride"], -                    ), -                    nn.ReLU(inplace=True), -                    nn.Dropout(self.dropout_rate), -                    nn.InstanceNorm2d(out_channels, affine=True), -                ] -            ) - -            for _ in range(tds_group["num_blocks"]): -                modules.append( -                    TDSBlock2d( -                        tds_group["channels"], -                        self.depth, -                        self.kernel_size, -                        self.dropout_rate, -                    ) -                ) - -            in_channels = out_channels - -        self.tds = nn.Sequential(*modules) -        self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) - -    def forward(self, x: Tensor) -> Tensor: -        """Forward pass. - -        Args: -            x (Tensor): Input tensor. - -        Shape: -            - x: :math: `(B, H, W)` - -        Returns: -            Tensor: Output tensor. - -        """ -        if len(x.shape) == 4: -            x = x.squeeze(1)  # Squeeze the channel dim away. - -        B, H, W = x.shape -        x = rearrange( -            x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels -        ) -        x = self.tds(x) - -        # x shape: [B, C, H, W] -        x = rearrange(x, "b c h w -> b w (c h)") - -        return self.fc(x) diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py deleted file mode 100644 index cadcecc..0000000 --- a/text_recognizer/networks/transducer/test.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import nn - -from text_recognizer.networks.transducer import load_transducer_loss, Transducer -import unittest - - -class TestTransducer(unittest.TestCase): -    def test_viterbi(self): -        T = 5 -        N = 4 -        B = 2 - -        # fmt: off -        emissions1 = torch.tensor(( -            0, 4, 0, 1, -            0, 2, 1, 1, -            0, 0, 0, 2, -            0, 0, 0, 2, -            8, 0, 0, 2, -            ), -            dtype=torch.float, -        ).view(T, N) -        emissions2 = torch.tensor(( -            0, 2, 1, 7, -            0, 2, 9, 1, -            0, 0, 0, 2, -            0, 0, 5, 2, -            1, 0, 0, 2, -            ), -            dtype=torch.float, -        ).view(T, N) -        # fmt: on - -        # Test without blank: -        labels = [[1, 3, 0], [3, 2, 3, 2, 3]] -        transducer = Transducer( -            tokens=["a", "b", "c", "d"], -            graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, -            blank="none", -        ) -        emissions = torch.stack([emissions1, emissions2], dim=0) -        predictions = transducer.viterbi(emissions) -        self.assertEqual([p.tolist() for p in predictions], labels) - -        # Test with blank without repeats: -        labels = [[1, 0], [2, 2]] -        transducer = Transducer( -            tokens=["a", "b", "c"], -            graphemes_to_idx={"a": 0, "b": 1, "c": 2}, -            blank="optional", -            allow_repeats=False, -        ) -        emissions = torch.stack([emissions1, emissions2], dim=0) -        predictions = transducer.viterbi(emissions) -        self.assertEqual([p.tolist() for p in predictions], labels) - - -if __name__ == "__main__": -    unittest.main() diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py deleted file mode 100644 index d7e3d08..0000000 --- a/text_recognizer/networks/transducer/transducer.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Transducer and the transducer loss function.py - -Stolen from: -    https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py - -""" -from pathlib import Path -import itertools -from typing import Dict, List, Optional, Union, Tuple - -from loguru import logger -import gtn -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.datasets.iam_preprocessor import Preprocessor - - -def make_scalar_graph(weight) -> gtn.Graph: -    scalar = gtn.Graph() -    scalar.add_node(True) -    scalar.add_node(False, True) -    scalar.add_arc(0, 1, 0, 0, weight) -    return scalar - - -def make_chain_graph(sequence) -> gtn.Graph: -    graph = gtn.Graph(False) -    graph.add_node(True) -    for i, s in enumerate(sequence): -        graph.add_node(False, i == (len(sequence) - 1)) -        graph.add_arc(i, i + 1, s) -    return graph - - -def make_transitions_graph( -    ngram: int, num_tokens: int, calc_grad: bool = False -) -> gtn.Graph: -    transitions = gtn.Graph(calc_grad) -    transitions.add_node(True, ngram == 1) - -    state_map = {(): 0} - -    # First build transitions which include <s>: -    for n in range(1, ngram): -        for state in itertools.product(range(num_tokens), repeat=n): -            in_idx = state_map[state[:-1]] -            out_idx = transitions.add_node(False, ngram == 1) -            state_map[state] = out_idx -            transitions.add_arc(in_idx, out_idx, state[-1]) - -    for state in itertools.product(range(num_tokens), repeat=ngram): -        state_idx = state_map[state[:-1]] -        new_state_idx = state_map[state[1:]] -        # p(state[-1] | state[:-1]) -        transitions.add_arc(state_idx, new_state_idx, state[-1]) - -    if ngram > 1: -        # Build transitions which include </s>: -        end_idx = transitions.add_node(False, True) -        for in_idx in range(end_idx): -            transitions.add_arc(in_idx, end_idx, gtn.epsilon) - -    return transitions - - -def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: -    """Constructs a graph which transduces letters to word pieces.""" -    graph = gtn.Graph(False) -    graph.add_node(True, True) -    for i, wp in enumerate(word_pieces): -        prev = 0 -        for l in wp[:-1]: -            n = graph.add_node() -            graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) -            prev = n -        graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) -    graph.arc_sort() -    return graph - - -def make_token_graph( -    token_list: List, blank: str = "none", allow_repeats: bool = True -) -> gtn.Graph: -    """Constructs a graph with all the individual token transition models.""" -    if not allow_repeats and blank != "optional": -        raise ValueError("Must use blank='optional' if disallowing repeats.") - -    ntoks = len(token_list) -    graph = gtn.Graph(False) - -    # Creating nodes -    graph.add_node(True, True) -    for i in range(ntoks): -        # We can consume one or more consecutive word -        # pieces for each emission: -        # E.g. [ab, ab, ab] transduces to [ab] -        graph.add_node(False, blank != "forced") - -    if blank != "none": -        graph.add_node() - -    # Creating arcs -    if blank != "none": -        # Blank index is assumed to be last (ntoks) -        graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) -        graph.add_arc(ntoks + 1, 0, gtn.epsilon) - -    for i in range(ntoks): -        graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) -        graph.add_arc(i + 1, i + 1, i, gtn.epsilon) - -        if allow_repeats: -            if blank == "forced": -                # Allow transitions from token to blank only -                graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) -            else: -                # Allow transition from token to blank and all other tokens -                graph.add_arc(i + 1, 0, gtn.epsilon) - -        else: -            # allow transitions to blank and all other tokens except the same token -            graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) -            for j in range(ntoks): -                if i != j: -                    graph.add_arc(i + 1, j + 1, j, j) - -    return graph - - -class TransducerLossFunction(torch.autograd.Function): -    @staticmethod -    def forward( -        ctx, -        inputs, -        targets, -        tokens, -        lexicon, -        transition_params=None, -        transitions=None, -        reduction="none", -    ) -> Tensor: -        B, T, C = inputs.shape - -        losses = [None] * B -        emissions_graphs = [None] * B - -        if transitions is not None: -            if transition_params is None: -                raise ValueError("Specified transitions, but not transition params.") - -            cpu_data = transition_params.cpu().contiguous() -            transitions.set_weights(cpu_data.data_ptr()) -            transitions.calc_grad = transition_params.requires_grad -            transitions.zero_grad() - -        def process(b: int) -> None: -            # Create emission graph: -            emissions = gtn.linear_graph(T, C, inputs.requires_grad) -            cpu_data = inputs[b].cpu().contiguous() -            emissions.set_weights(cpu_data.data_ptr()) -            target = make_chain_graph(targets[b]) -            target.arc_sort(True) - -            # Create token tot grapheme decomposition graph -            tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) -            tokens_target.arc_sort() - -            # Create alignment graph: -            aligments = gtn.project_input( -                gtn.remove(gtn.compose(tokens, tokens_target)) -            ) -            aligments.arc_sort() - -            # Add transitions scores: -            if transitions is not None: -                aligments = gtn.intersect(transitions, aligments) -                aligments.arc_sort() - -            loss = gtn.forward_score(gtn.intersect(emissions, aligments)) - -            # Normalize if needed: -            if transitions is not None: -                norm = gtn.forward_score(gtn.intersect(emissions, transitions)) -                loss = gtn.subtract(loss, norm) - -            losses[b] = gtn.negate(loss) - -            # Save for backward: -            if emissions.calc_grad: -                emissions_graphs[b] = emissions - -        gtn.parallel_for(process, range(B)) - -        ctx.graphs = (losses, emissions_graphs, transitions) -        ctx.input_shape = inputs.shape - -        # Optionally reduce by target length -        if reduction == "mean": -            scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] -        else: -            scales = [1.0] * B - -        ctx.scales = scales - -        loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) -        return torch.mean(loss.to(inputs.device)) - -    @staticmethod -    def backward(ctx, grad_output) -> Tuple: -        losses, emissions_graphs, transitions = ctx.graphs -        scales = ctx.scales - -        B, T, C = ctx.input_shape -        calc_emissions = ctx.needs_input_grad[0] -        input_grad = torch.empty((B, T, C)) if calc_emissions else None - -        def process(b: int) -> None: -            scale = make_scalar_graph(scales[b]) -            gtn.backward(losses[b], scale) -            emissions = emissions_graphs[b] -            if calc_emissions: -                grad = emissions.grad().weights_to_numpy() -                input_grad[b] = torch.tensor(grad).view(1, T, C) - -        gtn.parallel_for(process, range(B)) - -        if calc_emissions: -            input_grad = input_grad.to(grad_output.device) -            input_grad *= grad_output / B - -        if ctx.needs_input_grad[4]: -            grad = transitions.grad().weights_to_numpy() -            transition_grad = torch.tensor(grad).to(grad_output.device) -            transition_grad *= grad_output / B -        else: -            transition_grad = None - -        return ( -            input_grad, -            None,  # target -            None,  # tokens -            None,  # lexicon -            transition_grad,  # transition params -            None,  # transitions graph -            None, -        ) - - -TransducerLoss = TransducerLossFunction.apply - - -class Transducer(nn.Module): -    def __init__( -        self, -        tokens: List, -        graphemes_to_idx: Dict, -        ngram: int = 0, -        transitions: str = None, -        blank: str = "none", -        allow_repeats: bool = True, -        reduction: str = "none", -    ) -> None: -        """A generic transducer loss function. - -        Args: -            tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) -                representing the output tokens of the model (e.g. letters, -                word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] -                could be a list of sub-word tokens. -            graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. -                "a", "b", ..) to their corresponding integer index. -            ngram (int) : Order of the token-level transition model. If `ngram=0` -                then no transition model is used. -            blank (string) : Specifies the usage of blank token -                'none' - do not use blank token -                'optional' - allow an optional blank inbetween tokens -                'forced' - force a blank inbetween tokens (also referred to as garbage token) -            allow_repeats (boolean) : If false, then we don't allow paths with -                consecutive tokens in the alignment graph. This keeps the graph -                unambiguous in the sense that the same input cannot transduce to -                different outputs. -        """ -        super().__init__() -        if blank not in ["optional", "forced", "none"]: -            raise ValueError( -                "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']" -            ) -        self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) -        self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) -        self.ngram = ngram -        if ngram > 0 and transitions is not None: -            raise ValueError("Only one of ngram and transitions may be specified") - -        if ngram > 0: -            transitions = make_transitions_graph( -                ngram, len(tokens) + int(blank != "none"), True -            ) - -        if transitions is not None: -            self.transitions = transitions -            self.transitions.arc_sort() -            self.transitions_params = nn.Parameter( -                torch.zeros(self.transitions.num_arcs()) -            ) -        else: -            self.transitions = None -            self.transitions_params = None -        self.reduction = reduction - -    def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: -        TransducerLoss( -            inputs, -            targets, -            self.tokens, -            self.lexicon, -            self.transitions_params, -            self.transitions, -            self.reduction, -        ) - -    def viterbi(self, outputs: Tensor) -> List[Tensor]: -        B, T, C = outputs.shape - -        if self.transitions is not None: -            cpu_data = self.transition_params.cpu().contiguous() -            self.transitions.set_weights(cpu_data.data_ptr()) -            self.transitions.calc_grad = False - -        self.tokens.arc_sort() - -        paths = [None] * B - -        def process(b: int) -> None: -            emissions = gtn.linear_graph(T, C, False) -            cpu_data = outputs[b].cpu().contiguous() -            emissions.set_weights(cpu_data.data_ptr()) - -            if self.transitions is not None: -                full_graph = gtn.intersect(emissions, self.transitions) -            else: -                full_graph = emissions - -            # Find the best path and remove back-off arcs: -            path = gtn.remove(gtn.viterbi_path(full_graph)) - -            # Left compose the viterbi path with the "aligment to token" -            # transducer to get the outputs: -            path = gtn.compose(path, self.tokens) - -            # When there are ambiguous paths (allow_repeats is true), we take -            # the shortest: -            path = gtn.viterbi_path(path) -            path = gtn.remove(gtn.project_output(path)) -            paths[b] = path.labels_to_list() - -        gtn.parallel_for(process, range(B)) -        predictions = [torch.IntTensor(path) for path in paths] -        return predictions - - -def load_transducer_loss( -    num_features: int, -    ngram: int, -    tokens: str, -    lexicon: str, -    transitions: str, -    blank: str, -    allow_repeats: bool, -    prepend_wordsep: bool = False, -    use_words: bool = False, -    data_dir: Optional[Union[str, Path]] = None, -    reduction: str = "mean", -) -> Tuple[Transducer, int]: -    if data_dir is None: -        data_dir = ( -            Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" -        ) -        logger.debug(f"Using data dir: {data_dir}") -        if not data_dir.exists(): -            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") -    else: -        data_dir = Path(data_dir) -    processed_path = ( -        Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" -    ) -    tokens_path = processed_path / tokens -    lexicon_path = processed_path / lexicon - -    if transitions is not None: -        transitions = gtn.load(str(processed_path / transitions)) - -    preprocessor = Preprocessor( -        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, -    ) - -    num_tokens = preprocessor.num_tokens - -    criterion = Transducer( -        preprocessor.tokens, -        preprocessor.graphemes_to_index, -        ngram=ngram, -        transitions=transitions, -        blank=blank, -        allow_repeats=allow_repeats, -        reduction=reduction, -    ) - -    return criterion, num_tokens + int(blank != "none") |