Diffstat (limited to 'text_recognizer/networks/transducer')
-rw-r--r--  text_recognizer/networks/transducer/__init__.py      3
-rw-r--r--  text_recognizer/networks/transducer/tds_conv.py    208
-rw-r--r--  text_recognizer/networks/transducer/test.py         60
-rw-r--r--  text_recognizer/networks/transducer/transducer.py  410
4 files changed, 681 insertions, 0 deletions
diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py
new file mode 100644
index 0000000..8c19a01
--- /dev/null
+++ b/text_recognizer/networks/transducer/__init__.py
@@ -0,0 +1,3 @@
+"""Transducer modules."""
+from .tds_conv import TDS2d
+from .transducer import load_transducer_loss, Transducer
diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py
new file mode 100644
index 0000000..5fb8ba9
--- /dev/null
+++ b/text_recognizer/networks/transducer/tds_conv.py
@@ -0,0 +1,208 @@
+"""Time-Depth Separable Convolutions.
+
+References:
+ https://arxiv.org/abs/1904.02619
+ https://arxiv.org/pdf/2010.01003.pdf
+
+Code stolen from:
+ https://github.com/facebookresearch/gtn_applications
+
+
+"""
+from typing import Dict, Tuple
+
+from einops import rearrange
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class TDSBlock2d(nn.Module):
+ """Internal block of a 2D TDSC network."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ img_depth: int,
+ kernel_size: Tuple[int, int],
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.img_depth = img_depth
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+ self.fc_dim = in_channels * img_depth
+
+ # Network placeholders.
+ self.conv = None
+ self.mlp = None
+ self.instance_norm = None
+
+ self._build_block()
+
+ def _build_block(self) -> None:
+ # Convolutional block.
+ self.conv = nn.Sequential(
+ nn.Conv3d(
+ in_channels=self.in_channels,
+ out_channels=self.in_channels,
+ kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
+ padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # MLP block.
+ self.mlp = nn.Sequential(
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.Linear(self.fc_dim, self.fc_dim),
+ nn.Dropout(self.dropout_rate),
+ )
+
+ # Instance norm.
+ self.instance_norm = nn.ModuleList(
+ [
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ nn.InstanceNorm2d(self.fc_dim, affine=True),
+ ]
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math:`(B, CD, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ B, CD, H, W = x.shape
+ C, D = self.in_channels, self.img_depth
+ residual = x
+ x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
+ x = self.conv(x)
+ x = rearrange(x, "b c d h w -> b (c d) h w")
+ x += residual
+
+ x = self.instance_norm[0](x)
+
+ x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
+ x = self.instance_norm[1](x)
+
+ # Output shape: [B, CD, H, W]
+ return x
+
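+# A minimal shape sanity check for TDSBlock2d (an illustrative sketch; the
+# channel, depth and kernel values below are arbitrary). The block maps
+# (B, C*D, H, W) to the same shape, which is what lets it be stacked freely
+# inside a TDS group:
+#
+#   block = TDSBlock2d(in_channels=4, img_depth=2, kernel_size=(5, 7), dropout_rate=0.1)
+#   out = block(torch.randn(1, 8, 16, 32))  # C*D = 8 -> out.shape == (1, 8, 16, 32)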
+
+class TDS2d(nn.Module):
+ """TDS network.
+
+ Structure is the following:
+ Downsample layer -> TDS2d group -> ... -> Linear output layer
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ output_dim: int,
+ depth: int,
+ tds_groups: Tuple[Dict],
+ kernel_size: Tuple[int, int],
+ dropout_rate: float,
+ in_channels: int = 1,
+ ) -> None:
+ super().__init__()
+
+ self.in_channels = in_channels
+ self.input_dim = input_dim
+ self.output_dim = output_dim
+ self.depth = depth
+ self.tds_groups = tds_groups
+ self.kernel_size = kernel_size
+ self.dropout_rate = dropout_rate
+
+ self.tds = None
+ self.fc = None
+
+ self._build_network()
+
+ def _build_network(self) -> None:
+ in_channels = self.in_channels
+ modules = []
+ stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
+ if self.input_dim % stride_h:
+ raise RuntimeError(
+ f"Image height not divisible by total stride {stride_h}."
+ )
+
+ for tds_group in self.tds_groups:
+ # Add downsample layer.
+ out_channels = self.depth * tds_group["channels"]
+ modules.extend(
+ [
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=self.kernel_size,
+ padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
+ stride=tds_group["stride"],
+ ),
+ nn.ReLU(inplace=True),
+ nn.Dropout(self.dropout_rate),
+ nn.InstanceNorm2d(out_channels, affine=True),
+ ]
+ )
+
+ for _ in range(tds_group["num_blocks"]):
+ modules.append(
+ TDSBlock2d(
+ tds_group["channels"],
+ self.depth,
+ self.kernel_size,
+ self.dropout_rate,
+ )
+ )
+
+ in_channels = out_channels
+
+ self.tds = nn.Sequential(*modules)
+ self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Forward pass.
+
+ Args:
+ x (Tensor): Input tensor.
+
+ Shape:
+ - x: :math:`(B, H, W)`
+
+ Returns:
+ Tensor: Output tensor.
+
+ """
+ if len(x.shape) == 4:
+ x = x.squeeze(1) # Squeeze the channel dim away.
+
+ B, H, W = x.shape
+ x = rearrange(
+ x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
+ )
+ x = self.tds(x)
+
+ # x shape: [B, C, H, W]
+ x = rearrange(x, "b c h w -> b w (c h)")
+
+ return self.fc(x)
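+
+# A hedged usage sketch (the group configuration below is made up for
+# illustration, not a recommended setting). Each entry in `tds_groups` is a
+# dict with "channels", "stride" and "num_blocks" keys, which _build_network
+# reads; the output is per-column logits of shape (B, W', output_dim):
+#
+#   net = TDS2d(
+#       input_dim=32,
+#       output_dim=57,
+#       depth=2,
+#       tds_groups=(
+#           {"channels": 4, "stride": (2, 2), "num_blocks": 2},
+#           {"channels": 8, "stride": (2, 2), "num_blocks": 2},
+#       ),
+#       kernel_size=(5, 7),
+#       dropout_rate=0.1,
+#   )
+#   logits = net(torch.randn(1, 32, 128))  # -> (1, 32, 57)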
diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py
new file mode 100644
index 0000000..cadcecc
--- /dev/null
+++ b/text_recognizer/networks/transducer/test.py
@@ -0,0 +1,60 @@
+import unittest
+
+import torch
+
+from text_recognizer.networks.transducer import load_transducer_loss, Transducer
+
+
+class TestTransducer(unittest.TestCase):
+ def test_viterbi(self):
+ T = 5
+ N = 4
+ B = 2
+
+ # fmt: off
+ emissions1 = torch.tensor((
+ 0, 4, 0, 1,
+ 0, 2, 1, 1,
+ 0, 0, 0, 2,
+ 0, 0, 0, 2,
+ 8, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ emissions2 = torch.tensor((
+ 0, 2, 1, 7,
+ 0, 2, 9, 1,
+ 0, 0, 0, 2,
+ 0, 0, 5, 2,
+ 1, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ # fmt: on
+
+ # Test without blank:
+ labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
+ transducer = Transducer(
+ tokens=["a", "b", "c", "d"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
+ blank="none",
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+ # Test with blank without repeats:
+ labels = [[1, 0], [2, 2]]
+ transducer = Transducer(
+ tokens=["a", "b", "c"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2},
+ blank="optional",
+ allow_repeats=False,
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
new file mode 100644
index 0000000..d7e3d08
--- /dev/null
+++ b/text_recognizer/networks/transducer/transducer.py
@@ -0,0 +1,410 @@
+"""Transducer and the transducer loss function.py
+
+Stolen from:
+ https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
+
+"""
+import itertools
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+import gtn
+from loguru import logger
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+
+
+def make_scalar_graph(weight: float) -> gtn.Graph:
+ scalar = gtn.Graph()
+ scalar.add_node(True)
+ scalar.add_node(False, True)
+ scalar.add_arc(0, 1, 0, 0, weight)
+ return scalar
+
+
+def make_chain_graph(sequence: List) -> gtn.Graph:
+ graph = gtn.Graph(False)
+ graph.add_node(True)
+ for i, s in enumerate(sequence):
+ graph.add_node(False, i == (len(sequence) - 1))
+ graph.add_arc(i, i + 1, s)
+ return graph
+
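+# For example (illustrative only), make_chain_graph([1, 3, 0]) yields a
+# 4-node linear acceptor start -> 1 -> 3 -> 0 -> accept for that label
+# sequence.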
+
+def make_transitions_graph(
+ ngram: int, num_tokens: int, calc_grad: bool = False
+) -> gtn.Graph:
+ transitions = gtn.Graph(calc_grad)
+ transitions.add_node(True, ngram == 1)
+
+ state_map = {(): 0}
+
+ # First build transitions which include <s>:
+ for n in range(1, ngram):
+ for state in itertools.product(range(num_tokens), repeat=n):
+ in_idx = state_map[state[:-1]]
+ out_idx = transitions.add_node(False, ngram == 1)
+ state_map[state] = out_idx
+ transitions.add_arc(in_idx, out_idx, state[-1])
+
+ for state in itertools.product(range(num_tokens), repeat=ngram):
+ state_idx = state_map[state[:-1]]
+ new_state_idx = state_map[state[1:]]
+ # p(state[-1] | state[:-1])
+ transitions.add_arc(state_idx, new_state_idx, state[-1])
+
+ if ngram > 1:
+ # Build transitions which include </s>:
+ end_idx = transitions.add_node(False, True)
+ for in_idx in range(end_idx):
+ transitions.add_arc(in_idx, end_idx, gtn.epsilon)
+
+ return transitions
+
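+# A small worked example (toy sizes, illustrative only): with ngram=2 and
+# num_tokens=2 the graph has a start node with one outgoing arc per token,
+# one node per unigram history, one arc per bigram, and a final </s> node
+# reachable from every state via epsilon:
+#
+#   bigram_transitions = make_transitions_graph(ngram=2, num_tokens=2)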
+
+def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
+ """Constructs a graph which transduces letters to word pieces."""
+ graph = gtn.Graph(False)
+ graph.add_node(True, True)
+ for i, wp in enumerate(word_pieces):
+ prev = 0
+ for l in wp[:-1]:
+ n = graph.add_node()
+ graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
+ prev = n
+ graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
+ graph.arc_sort()
+ return graph
+
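+# Illustrative example (toy inputs): with word_pieces=["a", "ab"] and
+# graphemes_to_idx={"a": 0, "b": 1}, the graph consumes grapheme indices and
+# emits the index of a word piece on reading its final grapheme:
+#
+#   lexicon = make_lexicon_graph(["a", "ab"], {"a": 0, "b": 1})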
+
+def make_token_graph(
+ token_list: List, blank: str = "none", allow_repeats: bool = True
+) -> gtn.Graph:
+ """Constructs a graph with all the individual token transition models."""
+ if not allow_repeats and blank != "optional":
+ raise ValueError("Must use blank='optional' if disallowing repeats.")
+
+ ntoks = len(token_list)
+ graph = gtn.Graph(False)
+
+ # Creating nodes
+ graph.add_node(True, True)
+ for i in range(ntoks):
+ # We can consume one or more consecutive word
+ # pieces for each emission:
+ # E.g. [ab, ab, ab] transduces to [ab]
+ graph.add_node(False, blank != "forced")
+
+ if blank != "none":
+ graph.add_node()
+
+ # Creating arcs
+ if blank != "none":
+ # Blank index is assumed to be last (ntoks)
+ graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
+ graph.add_arc(ntoks + 1, 0, gtn.epsilon)
+
+ for i in range(ntoks):
+ graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
+ graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
+
+ if allow_repeats:
+ if blank == "forced":
+ # Allow transitions from token to blank only
+ graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+ else:
+ # Allow transition from token to blank and all other tokens
+ graph.add_arc(i + 1, 0, gtn.epsilon)
+
+ else:
+ # Allow transitions to blank and all other tokens except the same token
+ graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+ for j in range(ntoks):
+ if i != j:
+ graph.add_arc(i + 1, j + 1, j, j)
+
+ return graph
+
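+# Illustrative example (toy tokens, mirroring the unit test in test.py): the
+# blank index is assumed to be len(token_list), so emissions need
+# len(token_list) + 1 classes whenever blank != "none":
+#
+#   tokens = make_token_graph(["a", "b", "c"], blank="optional", allow_repeats=False)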
+
+class TransducerLossFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ inputs,
+ targets,
+ tokens,
+ lexicon,
+ transition_params=None,
+ transitions=None,
+ reduction="none",
+ ) -> Tensor:
+ B, T, C = inputs.shape
+
+ losses = [None] * B
+ emissions_graphs = [None] * B
+
+ if transitions is not None:
+ if transition_params is None:
+ raise ValueError("Specified transitions, but not transition params.")
+
+ cpu_data = transition_params.cpu().contiguous()
+ transitions.set_weights(cpu_data.data_ptr())
+ transitions.calc_grad = transition_params.requires_grad
+ transitions.zero_grad()
+
+ def process(b: int) -> None:
+ # Create emission graph:
+ emissions = gtn.linear_graph(T, C, inputs.requires_grad)
+ cpu_data = inputs[b].cpu().contiguous()
+ emissions.set_weights(cpu_data.data_ptr())
+ target = make_chain_graph(targets[b])
+ target.arc_sort(True)
+
+ # Create token to grapheme decomposition graph:
+ tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
+ tokens_target.arc_sort()
+
+ # Create alignment graph:
+ alignments = gtn.project_input(
+ gtn.remove(gtn.compose(tokens, tokens_target))
+ )
+ alignments.arc_sort()
+
+ # Add transition scores:
+ if transitions is not None:
+ alignments = gtn.intersect(transitions, alignments)
+ alignments.arc_sort()
+
+ loss = gtn.forward_score(gtn.intersect(emissions, alignments))
+
+ # Normalize if needed:
+ if transitions is not None:
+ norm = gtn.forward_score(gtn.intersect(emissions, transitions))
+ loss = gtn.subtract(loss, norm)
+
+ losses[b] = gtn.negate(loss)
+
+ # Save for backward:
+ if emissions.calc_grad:
+ emissions_graphs[b] = emissions
+
+ gtn.parallel_for(process, range(B))
+
+ ctx.graphs = (losses, emissions_graphs, transitions)
+ ctx.input_shape = inputs.shape
+
+ # Optionally reduce by target length
+ if reduction == "mean":
+ scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
+ else:
+ scales = [1.0] * B
+
+ ctx.scales = scales
+
+ loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
+ return torch.mean(loss.to(inputs.device))
+
+ @staticmethod
+ def backward(ctx, grad_output) -> Tuple:
+ losses, emissions_graphs, transitions = ctx.graphs
+ scales = ctx.scales
+
+ B, T, C = ctx.input_shape
+ calc_emissions = ctx.needs_input_grad[0]
+ input_grad = torch.empty((B, T, C)) if calc_emissions else None
+
+ def process(b: int) -> None:
+ scale = make_scalar_graph(scales[b])
+ gtn.backward(losses[b], scale)
+ emissions = emissions_graphs[b]
+ if calc_emissions:
+ grad = emissions.grad().weights_to_numpy()
+ input_grad[b] = torch.tensor(grad).view(1, T, C)
+
+ gtn.parallel_for(process, range(B))
+
+ if calc_emissions:
+ input_grad = input_grad.to(grad_output.device)
+ input_grad *= grad_output / B
+
+ if ctx.needs_input_grad[4]:
+ grad = transitions.grad().weights_to_numpy()
+ transition_grad = torch.tensor(grad).to(grad_output.device)
+ transition_grad *= grad_output / B
+ else:
+ transition_grad = None
+
+ return (
+ input_grad,
+ None, # target
+ None, # tokens
+ None, # lexicon
+ transition_grad, # transition params
+ None, # transitions graph
+ None,
+ )
+
+
+TransducerLoss = TransducerLossFunction.apply
+
+
+class Transducer(nn.Module):
+ def __init__(
+ self,
+ tokens: List,
+ graphemes_to_idx: Dict,
+ ngram: int = 0,
+ transitions: Optional[gtn.Graph] = None,
+ blank: str = "none",
+ allow_repeats: bool = True,
+ reduction: str = "none",
+ ) -> None:
+ """A generic transducer loss function.
+
+ Args:
+ tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
+ representing the output tokens of the model (e.g. letters,
+ word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
+ could be a list of sub-word tokens.
+ graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
+ "a", "b", ..) to their corresponding integer index.
+ ngram (int) : Order of the token-level transition model. If `ngram=0`
+ then no transition model is used.
+ blank (string) : Specifies the usage of the blank token:
+ 'none' - do not use the blank token
+ 'optional' - allow an optional blank in between tokens
+ 'forced' - force a blank in between tokens (also referred to as the garbage token)
+ allow_repeats (boolean) : If false, then we don't allow paths with
+ consecutive tokens in the alignment graph. This keeps the graph
+ unambiguous in the sense that the same input cannot transduce to
+ different outputs.
+ """
+ super().__init__()
+ if blank not in ["optional", "forced", "none"]:
+ raise ValueError(
+ "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
+ )
+ self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
+ self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+ self.ngram = ngram
+ if ngram > 0 and transitions is not None:
+ raise ValueError("Only one of ngram and transitions may be specified")
+
+ if ngram > 0:
+ transitions = make_transitions_graph(
+ ngram, len(tokens) + int(blank != "none"), True
+ )
+
+ if transitions is not None:
+ self.transitions = transitions
+ self.transitions.arc_sort()
+ self.transitions_params = nn.Parameter(
+ torch.zeros(self.transitions.num_arcs())
+ )
+ else:
+ self.transitions = None
+ self.transitions_params = None
+ self.reduction = reduction
+
+ def forward(self, inputs: Tensor, targets: Tensor) -> Tensor:
+ return TransducerLoss(
+ inputs,
+ targets,
+ self.tokens,
+ self.lexicon,
+ self.transitions_params,
+ self.transitions,
+ self.reduction,
+ )
+
+ def viterbi(self, outputs: Tensor) -> List[Tensor]:
+ B, T, C = outputs.shape
+
+ if self.transitions is not None:
+ cpu_data = self.transitions_params.cpu().contiguous()
+ self.transitions.set_weights(cpu_data.data_ptr())
+ self.transitions.calc_grad = False
+
+ self.tokens.arc_sort()
+
+ paths = [None] * B
+
+ def process(b: int) -> None:
+ emissions = gtn.linear_graph(T, C, False)
+ cpu_data = outputs[b].cpu().contiguous()
+ emissions.set_weights(cpu_data.data_ptr())
+
+ if self.transitions is not None:
+ full_graph = gtn.intersect(emissions, self.transitions)
+ else:
+ full_graph = emissions
+
+ # Find the best path and remove back-off arcs:
+ path = gtn.remove(gtn.viterbi_path(full_graph))
+
+ # Left compose the viterbi path with the "alignment to token"
+ # transducer to get the outputs:
+ path = gtn.compose(path, self.tokens)
+
+ # When there are ambiguous paths (allow_repeats is true), we take
+ # the shortest:
+ path = gtn.viterbi_path(path)
+ path = gtn.remove(gtn.project_output(path))
+ paths[b] = path.labels_to_list()
+
+ gtn.parallel_for(process, range(B))
+ predictions = [torch.IntTensor(path) for path in paths]
+ return predictions
+
+
+def load_transducer_loss(
+ num_features: int,
+ ngram: int,
+ tokens: str,
+ lexicon: str,
+ transitions: Optional[str],
+ blank: str,
+ allow_repeats: bool,
+ prepend_wordsep: bool = False,
+ use_words: bool = False,
+ data_dir: Optional[Union[str, Path]] = None,
+ reduction: str = "mean",
+) -> Tuple[Transducer, int]:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ if transitions is not None:
+ transitions = gtn.load(str(processed_path / transitions))
+
+ preprocessor = Preprocessor(
+ data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ )
+
+ num_tokens = preprocessor.num_tokens
+
+ criterion = Transducer(
+ preprocessor.tokens,
+ preprocessor.graphemes_to_index,
+ ngram=ngram,
+ transitions=transitions,
+ blank=blank,
+ allow_repeats=allow_repeats,
+ reduction=reduction,
+ )
+
+ return criterion, num_tokens + int(blank != "none")
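+
+# A hedged end-to-end sketch (toy vocabulary and random emissions, not a
+# recommended configuration). Transducer.forward returns the loss tensor and
+# viterbi decodes the most likely token sequence per batch element:
+#
+#   criterion = Transducer(
+#       tokens=["a", "b", "c"],
+#       graphemes_to_idx={"a": 0, "b": 1, "c": 2},
+#       blank="optional",
+#       allow_repeats=False,
+#       reduction="mean",
+#   )
+#   emissions = torch.randn(2, 5, 4)  # (B, T, num_tokens + blank)
+#   loss = criterion(emissions, [[0, 1], [2]])
+#   predictions = criterion.viterbi(emissions)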