From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 20 Mar 2021 18:09:06 +0100
Subject: Inital commit for refactoring to lightning

---
 .../networks/transducer/__init__.py                |   3 -
 .../networks/transducer/tds_conv.py                | 208 -----------
 src/text_recognizer/networks/transducer/test.py    |  60 ---
 .../networks/transducer/transducer.py              | 410 ---------------------
 4 files changed, 681 deletions(-)
 delete mode 100644 src/text_recognizer/networks/transducer/__init__.py
 delete mode 100644 src/text_recognizer/networks/transducer/tds_conv.py
 delete mode 100644 src/text_recognizer/networks/transducer/test.py
 delete mode 100644 src/text_recognizer/networks/transducer/transducer.py

(limited to 'src/text_recognizer/networks/transducer')

diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
deleted file mode 100644
index 8c19a01..0000000
--- a/src/text_recognizer/networks/transducer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Transducer modules."""
-from .tds_conv import TDS2d
-from .transducer import load_transducer_loss, Transducer
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
deleted file mode 100644
index 5fb8ba9..0000000
--- a/src/text_recognizer/networks/transducer/tds_conv.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""Time-Depth Separable Convolutions.
-
-References:
-    https://arxiv.org/abs/1904.02619
-    https://arxiv.org/pdf/2010.01003.pdf
-
-Code stolen from:
-    https://github.com/facebookresearch/gtn_applications
-
-
-"""
-from typing import List, Tuple
-
-from einops import rearrange
-import gtn
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class TDSBlock2d(nn.Module):
-    """Internal block of a 2D TDSC network."""
-
-    def __init__(
-        self,
-        in_channels: int,
-        img_depth: int,
-        kernel_size: Tuple[int],
-        dropout_rate: float,
-    ) -> None:
-        super().__init__()
-
-        self.in_channels = in_channels
-        self.img_depth = img_depth
-        self.kernel_size = kernel_size
-        self.dropout_rate = dropout_rate
-        self.fc_dim = in_channels * img_depth
-
-        # Network placeholders.
-        self.conv = None
-        self.mlp = None
-        self.instance_norm = None
-
-        self._build_block()
-
-    def _build_block(self) -> None:
-        # Convolutional block.
-        self.conv = nn.Sequential(
-            nn.Conv3d(
-                in_channels=self.in_channels,
-                out_channels=self.in_channels,
-                kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
-                padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
-            ),
-            nn.ReLU(inplace=True),
-            nn.Dropout(self.dropout_rate),
-        )
-
-        # MLP block.
-        self.mlp = nn.Sequential(
-            nn.Linear(self.fc_dim, self.fc_dim),
-            nn.ReLU(inplace=True),
-            nn.Dropout(self.dropout_rate),
-            nn.Linear(self.fc_dim, self.fc_dim),
-            nn.Dropout(self.dropout_rate),
-        )
-
-        # Instance norm.
-        self.instance_norm = nn.ModuleList(
-            [
-                nn.InstanceNorm2d(self.fc_dim, affine=True),
-                nn.InstanceNorm2d(self.fc_dim, affine=True),
-            ]
-        )
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass.
-
-        Args:
-            x (Tensor): Input tensor.
-
-        Shape:
-            - x: :math: `(B, CD, H, W)`
-
-        Returns:
-            Tensor: Output tensor.
-
-        """
-        B, CD, H, W = x.shape
-        C, D = self.in_channels, self.img_depth
-        residual = x
-        x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
-        x = self.conv(x)
-        x = rearrange(x, "b c d h w -> b (c d) h w")
-        x += residual
-
-        x = self.instance_norm[0](x)
-
-        x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
-        x + self.instance_norm[1](x)
-
-        # Output shape: [B, CD, H, W]
-        return x
-
-
-class TDS2d(nn.Module):
-    """TDS Netowrk.
-
-    Structure is the following:
-        Downsample layer -> TDS2d group -> ... -> Linear output layer
-
-
-    """
-
-    def __init__(
-        self,
-        input_dim: int,
-        output_dim: int,
-        depth: int,
-        tds_groups: Tuple[int],
-        kernel_size: Tuple[int],
-        dropout_rate: float,
-        in_channels: int = 1,
-    ) -> None:
-        super().__init__()
-
-        self.in_channels = in_channels
-        self.input_dim = input_dim
-        self.output_dim = output_dim
-        self.depth = depth
-        self.tds_groups = tds_groups
-        self.kernel_size = kernel_size
-        self.dropout_rate = dropout_rate
-
-        self.tds = None
-        self.fc = None
-
-        self._build_network()
-
-    def _build_network(self) -> None:
-        in_channels = self.in_channels
-        modules = []
-        stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
-        if self.input_dim % stride_h:
-            raise RuntimeError(
-                f"Image height not divisible by total stride {stride_h}."
-            )
-
-        for tds_group in self.tds_groups:
-            # Add downsample layer.
-            out_channels = self.depth * tds_group["channels"]
-            modules.extend(
-                [
-                    nn.Conv2d(
-                        in_channels=in_channels,
-                        out_channels=out_channels,
-                        kernel_size=self.kernel_size,
-                        padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
-                        stride=tds_group["stride"],
-                    ),
-                    nn.ReLU(inplace=True),
-                    nn.Dropout(self.dropout_rate),
-                    nn.InstanceNorm2d(out_channels, affine=True),
-                ]
-            )
-
-            for _ in range(tds_group["num_blocks"]):
-                modules.append(
-                    TDSBlock2d(
-                        tds_group["channels"],
-                        self.depth,
-                        self.kernel_size,
-                        self.dropout_rate,
-                    )
-                )
-
-            in_channels = out_channels
-
-        self.tds = nn.Sequential(*modules)
-        self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Forward pass.
-
-        Args:
-            x (Tensor): Input tensor.
-
-        Shape:
-            - x: :math: `(B, H, W)`
-
-        Returns:
-            Tensor: Output tensor.
-
-        """
-        if len(x.shape) == 4:
-            x = x.squeeze(1)  # Squeeze the channel dim away.
-
-        B, H, W = x.shape
-        x = rearrange(
-            x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
-        )
-        x = self.tds(x)
-
-        # x shape: [B, C, H, W]
-        x = rearrange(x, "b c h w -> b w (c h)")
-
-        return self.fc(x)
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py
deleted file mode 100644
index cadcecc..0000000
--- a/src/text_recognizer/networks/transducer/test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch
-from torch import nn
-
-from text_recognizer.networks.transducer import load_transducer_loss, Transducer
-import unittest
-
-
-class TestTransducer(unittest.TestCase):
-    def test_viterbi(self):
-        T = 5
-        N = 4
-        B = 2
-
-        # fmt: off
-        emissions1 = torch.tensor((
-            0, 4, 0, 1,
-            0, 2, 1, 1,
-            0, 0, 0, 2,
-            0, 0, 0, 2,
-            8, 0, 0, 2,
-            ),
-            dtype=torch.float,
-        ).view(T, N)
-        emissions2 = torch.tensor((
-            0, 2, 1, 7,
-            0, 2, 9, 1,
-            0, 0, 0, 2,
-            0, 0, 5, 2,
-            1, 0, 0, 2,
-            ),
-            dtype=torch.float,
-        ).view(T, N)
-        # fmt: on
-
-        # Test without blank:
-        labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
-        transducer = Transducer(
-            tokens=["a", "b", "c", "d"],
-            graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
-            blank="none",
-        )
-        emissions = torch.stack([emissions1, emissions2], dim=0)
-        predictions = transducer.viterbi(emissions)
-        self.assertEqual([p.tolist() for p in predictions], labels)
-
-        # Test with blank without repeats:
-        labels = [[1, 0], [2, 2]]
-        transducer = Transducer(
-            tokens=["a", "b", "c"],
-            graphemes_to_idx={"a": 0, "b": 1, "c": 2},
-            blank="optional",
-            allow_repeats=False,
-        )
-        emissions = torch.stack([emissions1, emissions2], dim=0)
-        predictions = transducer.viterbi(emissions)
-        self.assertEqual([p.tolist() for p in predictions], labels)
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/src/text_recognizer/networks/transducer/transducer.py b/src/text_recognizer/networks/transducer/transducer.py
deleted file mode 100644
index d7e3d08..0000000
--- a/src/text_recognizer/networks/transducer/transducer.py
+++ /dev/null
@@ -1,410 +0,0 @@
-"""Transducer and the transducer loss function.py
-
-Stolen from:
-    https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
-
-"""
-from pathlib import Path
-import itertools
-from typing import Dict, List, Optional, Union, Tuple
-
-from loguru import logger
-import gtn
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-
-
-def make_scalar_graph(weight) -> gtn.Graph:
-    scalar = gtn.Graph()
-    scalar.add_node(True)
-    scalar.add_node(False, True)
-    scalar.add_arc(0, 1, 0, 0, weight)
-    return scalar
-
-
-def make_chain_graph(sequence) -> gtn.Graph:
-    graph = gtn.Graph(False)
-    graph.add_node(True)
-    for i, s in enumerate(sequence):
-        graph.add_node(False, i == (len(sequence) - 1))
-        graph.add_arc(i, i + 1, s)
-    return graph
-
-
-def make_transitions_graph(
-    ngram: int, num_tokens: int, calc_grad: bool = False
-) -> gtn.Graph:
-    transitions = gtn.Graph(calc_grad)
-    transitions.add_node(True, ngram == 1)
-
-    state_map = {(): 0}
-
-    # First build transitions which include <s>:
-    for n in range(1, ngram):
-        for state in itertools.product(range(num_tokens), repeat=n):
-            in_idx = state_map[state[:-1]]
-            out_idx = transitions.add_node(False, ngram == 1)
-            state_map[state] = out_idx
-            transitions.add_arc(in_idx, out_idx, state[-1])
-
-    for state in itertools.product(range(num_tokens), repeat=ngram):
-        state_idx = state_map[state[:-1]]
-        new_state_idx = state_map[state[1:]]
-        # p(state[-1] | state[:-1])
-        transitions.add_arc(state_idx, new_state_idx, state[-1])
-
-    if ngram > 1:
-        # Build transitions which include </s>:
-        end_idx = transitions.add_node(False, True)
-        for in_idx in range(end_idx):
-            transitions.add_arc(in_idx, end_idx, gtn.epsilon)
-
-    return transitions
-
-
-def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
-    """Constructs a graph which transduces letters to word pieces."""
-    graph = gtn.Graph(False)
-    graph.add_node(True, True)
-    for i, wp in enumerate(word_pieces):
-        prev = 0
-        for l in wp[:-1]:
-            n = graph.add_node()
-            graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
-            prev = n
-        graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
-    graph.arc_sort()
-    return graph
-
-
-def make_token_graph(
-    token_list: List, blank: str = "none", allow_repeats: bool = True
-) -> gtn.Graph:
-    """Constructs a graph with all the individual token transition models."""
-    if not allow_repeats and blank != "optional":
-        raise ValueError("Must use blank='optional' if disallowing repeats.")
-
-    ntoks = len(token_list)
-    graph = gtn.Graph(False)
-
-    # Creating nodes
-    graph.add_node(True, True)
-    for i in range(ntoks):
-        # We can consume one or more consecutive word
-        # pieces for each emission:
-        # E.g. [ab, ab, ab] transduces to [ab]
-        graph.add_node(False, blank != "forced")
-
-    if blank != "none":
-        graph.add_node()
-
-    # Creating arcs
-    if blank != "none":
-        # Blank index is assumed to be last (ntoks)
-        graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
-        graph.add_arc(ntoks + 1, 0, gtn.epsilon)
-
-    for i in range(ntoks):
-        graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
-        graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
-
-        if allow_repeats:
-            if blank == "forced":
-                # Allow transitions from token to blank only
-                graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
-            else:
-                # Allow transition from token to blank and all other tokens
-                graph.add_arc(i + 1, 0, gtn.epsilon)
-
-        else:
-            # allow transitions to blank and all other tokens except the same token
-            graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
-            for j in range(ntoks):
-                if i != j:
-                    graph.add_arc(i + 1, j + 1, j, j)
-
-    return graph
-
-
-class TransducerLossFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx,
-        inputs,
-        targets,
-        tokens,
-        lexicon,
-        transition_params=None,
-        transitions=None,
-        reduction="none",
-    ) -> Tensor:
-        B, T, C = inputs.shape
-
-        losses = [None] * B
-        emissions_graphs = [None] * B
-
-        if transitions is not None:
-            if transition_params is None:
-                raise ValueError("Specified transitions, but not transition params.")
-
-            cpu_data = transition_params.cpu().contiguous()
-            transitions.set_weights(cpu_data.data_ptr())
-            transitions.calc_grad = transition_params.requires_grad
-            transitions.zero_grad()
-
-        def process(b: int) -> None:
-            # Create emission graph:
-            emissions = gtn.linear_graph(T, C, inputs.requires_grad)
-            cpu_data = inputs[b].cpu().contiguous()
-            emissions.set_weights(cpu_data.data_ptr())
-            target = make_chain_graph(targets[b])
-            target.arc_sort(True)
-
-            # Create token tot grapheme decomposition graph
-            tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
-            tokens_target.arc_sort()
-
-            # Create alignment graph:
-            aligments = gtn.project_input(
-                gtn.remove(gtn.compose(tokens, tokens_target))
-            )
-            aligments.arc_sort()
-
-            # Add transitions scores:
-            if transitions is not None:
-                aligments = gtn.intersect(transitions, aligments)
-                aligments.arc_sort()
-
-            loss = gtn.forward_score(gtn.intersect(emissions, aligments))
-
-            # Normalize if needed:
-            if transitions is not None:
-                norm = gtn.forward_score(gtn.intersect(emissions, transitions))
-                loss = gtn.subtract(loss, norm)
-
-            losses[b] = gtn.negate(loss)
-
-            # Save for backward:
-            if emissions.calc_grad:
-                emissions_graphs[b] = emissions
-
-        gtn.parallel_for(process, range(B))
-
-        ctx.graphs = (losses, emissions_graphs, transitions)
-        ctx.input_shape = inputs.shape
-
-        # Optionally reduce by target length
-        if reduction == "mean":
-            scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
-        else:
-            scales = [1.0] * B
-
-        ctx.scales = scales
-
-        loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
-        return torch.mean(loss.to(inputs.device))
-
-    @staticmethod
-    def backward(ctx, grad_output) -> Tuple:
-        losses, emissions_graphs, transitions = ctx.graphs
-        scales = ctx.scales
-
-        B, T, C = ctx.input_shape
-        calc_emissions = ctx.needs_input_grad[0]
-        input_grad = torch.empty((B, T, C)) if calc_emissions else None
-
-        def process(b: int) -> None:
-            scale = make_scalar_graph(scales[b])
-            gtn.backward(losses[b], scale)
-            emissions = emissions_graphs[b]
-            if calc_emissions:
-                grad = emissions.grad().weights_to_numpy()
-                input_grad[b] = torch.tensor(grad).view(1, T, C)
-
-        gtn.parallel_for(process, range(B))
-
-        if calc_emissions:
-            input_grad = input_grad.to(grad_output.device)
-            input_grad *= grad_output / B
-
-        if ctx.needs_input_grad[4]:
-            grad = transitions.grad().weights_to_numpy()
-            transition_grad = torch.tensor(grad).to(grad_output.device)
-            transition_grad *= grad_output / B
-        else:
-            transition_grad = None
-
-        return (
-            input_grad,
-            None,  # target
-            None,  # tokens
-            None,  # lexicon
-            transition_grad,  # transition params
-            None,  # transitions graph
-            None,
-        )
-
-
-TransducerLoss = TransducerLossFunction.apply
-
-
-class Transducer(nn.Module):
-    def __init__(
-        self,
-        tokens: List,
-        graphemes_to_idx: Dict,
-        ngram: int = 0,
-        transitions: str = None,
-        blank: str = "none",
-        allow_repeats: bool = True,
-        reduction: str = "none",
-    ) -> None:
-        """A generic transducer loss function.
-
-        Args:
-            tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
-                representing the output tokens of the model (e.g. letters,
-                word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
-                could be a list of sub-word tokens.
-            graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
-                "a", "b", ..) to their corresponding integer index.
-            ngram (int) : Order of the token-level transition model. If `ngram=0`
-                then no transition model is used.
-            blank (string) : Specifies the usage of blank token
-                'none' - do not use blank token
-                'optional' - allow an optional blank inbetween tokens
-                'forced' - force a blank inbetween tokens (also referred to as garbage token)
-            allow_repeats (boolean) : If false, then we don't allow paths with
-                consecutive tokens in the alignment graph. This keeps the graph
-                unambiguous in the sense that the same input cannot transduce to
-                different outputs.
-        """
-        super().__init__()
-        if blank not in ["optional", "forced", "none"]:
-            raise ValueError(
-                "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
-            )
-        self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
-        self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
-        self.ngram = ngram
-        if ngram > 0 and transitions is not None:
-            raise ValueError("Only one of ngram and transitions may be specified")
-
-        if ngram > 0:
-            transitions = make_transitions_graph(
-                ngram, len(tokens) + int(blank != "none"), True
-            )
-
-        if transitions is not None:
-            self.transitions = transitions
-            self.transitions.arc_sort()
-            self.transitions_params = nn.Parameter(
-                torch.zeros(self.transitions.num_arcs())
-            )
-        else:
-            self.transitions = None
-            self.transitions_params = None
-        self.reduction = reduction
-
-    def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
-        TransducerLoss(
-            inputs,
-            targets,
-            self.tokens,
-            self.lexicon,
-            self.transitions_params,
-            self.transitions,
-            self.reduction,
-        )
-
-    def viterbi(self, outputs: Tensor) -> List[Tensor]:
-        B, T, C = outputs.shape
-
-        if self.transitions is not None:
-            cpu_data = self.transition_params.cpu().contiguous()
-            self.transitions.set_weights(cpu_data.data_ptr())
-            self.transitions.calc_grad = False
-
-        self.tokens.arc_sort()
-
-        paths = [None] * B
-
-        def process(b: int) -> None:
-            emissions = gtn.linear_graph(T, C, False)
-            cpu_data = outputs[b].cpu().contiguous()
-            emissions.set_weights(cpu_data.data_ptr())
-
-            if self.transitions is not None:
-                full_graph = gtn.intersect(emissions, self.transitions)
-            else:
-                full_graph = emissions
-
-            # Find the best path and remove back-off arcs:
-            path = gtn.remove(gtn.viterbi_path(full_graph))
-
-            # Left compose the viterbi path with the "aligment to token"
-            # transducer to get the outputs:
-            path = gtn.compose(path, self.tokens)
-
-            # When there are ambiguous paths (allow_repeats is true), we take
-            # the shortest:
-            path = gtn.viterbi_path(path)
-            path = gtn.remove(gtn.project_output(path))
-            paths[b] = path.labels_to_list()
-
-        gtn.parallel_for(process, range(B))
-        predictions = [torch.IntTensor(path) for path in paths]
-        return predictions
-
-
-def load_transducer_loss(
-    num_features: int,
-    ngram: int,
-    tokens: str,
-    lexicon: str,
-    transitions: str,
-    blank: str,
-    allow_repeats: bool,
-    prepend_wordsep: bool = False,
-    use_words: bool = False,
-    data_dir: Optional[Union[str, Path]] = None,
-    reduction: str = "mean",
-) -> Tuple[Transducer, int]:
-    if data_dir is None:
-        data_dir = (
-            Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
-        )
-        logger.debug(f"Using data dir: {data_dir}")
-        if not data_dir.exists():
-            raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
-    else:
-        data_dir = Path(data_dir)
-    processed_path = (
-        Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
-    )
-    tokens_path = processed_path / tokens
-    lexicon_path = processed_path / lexicon
-
-    if transitions is not None:
-        transitions = gtn.load(str(processed_path / transitions))
-
-    preprocessor = Preprocessor(
-        data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
-    )
-
-    num_tokens = preprocessor.num_tokens
-
-    criterion = Transducer(
-        preprocessor.tokens,
-        preprocessor.graphemes_to_index,
-        ngram=ngram,
-        transitions=transitions,
-        blank=blank,
-        allow_repeats=allow_repeats,
-        reduction=reduction,
-    )
-
-    return criterion, num_tokens + int(blank != "none")
-- 
cgit v1.2.3-70-g09d2