summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks')
-rw-r--r--src/text_recognizer/networks/__init__.py3
-rw-r--r--src/text_recognizer/networks/cnn_transformer.py4
-rw-r--r--src/text_recognizer/networks/transducer/__init__.py1
-rw-r--r--src/text_recognizer/networks/transducer/tds_conv.py15
-rw-r--r--src/text_recognizer/networks/transducer/test.py60
-rw-r--r--src/text_recognizer/networks/transducer/transducer.py410
6 files changed, 484 insertions, 9 deletions
diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py
index bac5d28..1521355 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/src/text_recognizer/networks/__init__.py
@@ -8,7 +8,7 @@ from .lenet import LeNet
from .metrics import accuracy, cer, wer
from .mlp import MLP
from .residual_network import ResidualNetwork, ResidualNetworkEncoder
-from .transducer import TDS2d
+from .transducer import load_transducer_loss, TDS2d
from .transformer import Transformer
from .unet import UNet
from .util import sliding_window
@@ -28,6 +28,7 @@ __all__ = [
"greedy_decoder",
"MLP",
"LeNet",
+ "load_transducer_loss",
"ResidualNetwork",
"ResidualNetworkEncoder",
"sliding_window",
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/src/text_recognizer/networks/cnn_transformer.py
index 7133c26..a2d7926 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/src/text_recognizer/networks/cnn_transformer.py
@@ -112,11 +112,11 @@ class CNNTransformer(nn.Module):
if self.max_pool is not None:
src = self.max_pool(src)
- if self.adaptive_pool is not None:
+ if self.adaptive_pool is not None and len(src.shape) == 4:
src = rearrange(src, "b c h w -> b w c h")
src = self.adaptive_pool(src)
src = src.squeeze(3)
- else:
+ elif len(src.shape) == 4:
src = rearrange(src, "b c h w -> b (h w) c")
b, t, _ = src.shape
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
index fdd6662..8c19a01 100644
--- a/src/text_recognizer/networks/transducer/__init__.py
+++ b/src/text_recognizer/networks/transducer/__init__.py
@@ -1,2 +1,3 @@
"""Transducer modules."""
from .tds_conv import TDS2d
+from .transducer import load_transducer_loss, Transducer
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
index 018caf2..5fb8ba9 100644
--- a/src/text_recognizer/networks/transducer/tds_conv.py
+++ b/src/text_recognizer/networks/transducer/tds_conv.py
@@ -136,8 +136,10 @@ class TDS2d(nn.Module):
self.tds = None
self.fc = None
- def _build_network(self) -> None:
+ self._build_network()
+ def _build_network(self) -> None:
+ in_channels = self.in_channels
modules = []
stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
if self.input_dim % stride_h:
@@ -151,7 +153,7 @@ class TDS2d(nn.Module):
modules.extend(
[
nn.Conv2d(
- in_channels=self.in_channels,
+ in_channels=in_channels,
out_channels=out_channels,
kernel_size=self.kernel_size,
padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
@@ -173,12 +175,10 @@ class TDS2d(nn.Module):
)
)
- self.in_channels = out_channels
+ in_channels = out_channels
self.tds = nn.Sequential(*modules)
- self.fc = nn.Linear(
- self.in_channels * self.input_dim // stride_h, self.output_dim
- )
+ self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
def forward(self, x: Tensor) -> Tensor:
"""Forward pass.
@@ -193,6 +193,9 @@ class TDS2d(nn.Module):
Tensor: Output tensor.
"""
+ if len(x.shape) == 4:
+ x = x.squeeze(1) # Squeeze the channel dim away.
+
B, H, W = x.shape
x = rearrange(
x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py
new file mode 100644
index 0000000..cadcecc
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/test.py
@@ -0,0 +1,60 @@
+import torch
+from torch import nn
+
+from text_recognizer.networks.transducer import load_transducer_loss, Transducer
+import unittest
+
+
+class TestTransducer(unittest.TestCase):
+ def test_viterbi(self):
+ T = 5
+ N = 4
+ B = 2
+
+ # fmt: off
+ emissions1 = torch.tensor((
+ 0, 4, 0, 1,
+ 0, 2, 1, 1,
+ 0, 0, 0, 2,
+ 0, 0, 0, 2,
+ 8, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ emissions2 = torch.tensor((
+ 0, 2, 1, 7,
+ 0, 2, 9, 1,
+ 0, 0, 0, 2,
+ 0, 0, 5, 2,
+ 1, 0, 0, 2,
+ ),
+ dtype=torch.float,
+ ).view(T, N)
+ # fmt: on
+
+ # Test without blank:
+ labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
+ transducer = Transducer(
+ tokens=["a", "b", "c", "d"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
+ blank="none",
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+ # Test with blank without repeats:
+ labels = [[1, 0], [2, 2]]
+ transducer = Transducer(
+ tokens=["a", "b", "c"],
+ graphemes_to_idx={"a": 0, "b": 1, "c": 2},
+ blank="optional",
+ allow_repeats=False,
+ )
+ emissions = torch.stack([emissions1, emissions2], dim=0)
+ predictions = transducer.viterbi(emissions)
+ self.assertEqual([p.tolist() for p in predictions], labels)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/src/text_recognizer/networks/transducer/transducer.py b/src/text_recognizer/networks/transducer/transducer.py
new file mode 100644
index 0000000..d7e3d08
--- /dev/null
+++ b/src/text_recognizer/networks/transducer/transducer.py
@@ -0,0 +1,410 @@
+"""Transducer and the transducer loss function.py
+
+Stolen from:
+ https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
+
+"""
+from pathlib import Path
+import itertools
+from typing import Dict, List, Optional, Union, Tuple
+
+from loguru import logger
+import gtn
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.datasets.iam_preprocessor import Preprocessor
+
+
+def make_scalar_graph(weight) -> gtn.Graph:
+ scalar = gtn.Graph()
+ scalar.add_node(True)
+ scalar.add_node(False, True)
+ scalar.add_arc(0, 1, 0, 0, weight)
+ return scalar
+
+
+def make_chain_graph(sequence) -> gtn.Graph:
+ graph = gtn.Graph(False)
+ graph.add_node(True)
+ for i, s in enumerate(sequence):
+ graph.add_node(False, i == (len(sequence) - 1))
+ graph.add_arc(i, i + 1, s)
+ return graph
+
+
+def make_transitions_graph(
+ ngram: int, num_tokens: int, calc_grad: bool = False
+) -> gtn.Graph:
+ transitions = gtn.Graph(calc_grad)
+ transitions.add_node(True, ngram == 1)
+
+ state_map = {(): 0}
+
+ # First build transitions which include <s>:
+ for n in range(1, ngram):
+ for state in itertools.product(range(num_tokens), repeat=n):
+ in_idx = state_map[state[:-1]]
+ out_idx = transitions.add_node(False, ngram == 1)
+ state_map[state] = out_idx
+ transitions.add_arc(in_idx, out_idx, state[-1])
+
+ for state in itertools.product(range(num_tokens), repeat=ngram):
+ state_idx = state_map[state[:-1]]
+ new_state_idx = state_map[state[1:]]
+ # p(state[-1] | state[:-1])
+ transitions.add_arc(state_idx, new_state_idx, state[-1])
+
+ if ngram > 1:
+ # Build transitions which include </s>:
+ end_idx = transitions.add_node(False, True)
+ for in_idx in range(end_idx):
+ transitions.add_arc(in_idx, end_idx, gtn.epsilon)
+
+ return transitions
+
+
+def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
+ """Constructs a graph which transduces letters to word pieces."""
+ graph = gtn.Graph(False)
+ graph.add_node(True, True)
+ for i, wp in enumerate(word_pieces):
+ prev = 0
+ for l in wp[:-1]:
+ n = graph.add_node()
+ graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
+ prev = n
+ graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
+ graph.arc_sort()
+ return graph
+
+
+def make_token_graph(
+ token_list: List, blank: str = "none", allow_repeats: bool = True
+) -> gtn.Graph:
+ """Constructs a graph with all the individual token transition models."""
+ if not allow_repeats and blank != "optional":
+ raise ValueError("Must use blank='optional' if disallowing repeats.")
+
+ ntoks = len(token_list)
+ graph = gtn.Graph(False)
+
+ # Creating nodes
+ graph.add_node(True, True)
+ for i in range(ntoks):
+ # We can consume one or more consecutive word
+ # pieces for each emission:
+ # E.g. [ab, ab, ab] transduces to [ab]
+ graph.add_node(False, blank != "forced")
+
+ if blank != "none":
+ graph.add_node()
+
+ # Creating arcs
+ if blank != "none":
+ # Blank index is assumed to be last (ntoks)
+ graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
+ graph.add_arc(ntoks + 1, 0, gtn.epsilon)
+
+ for i in range(ntoks):
+ graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
+ graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
+
+ if allow_repeats:
+ if blank == "forced":
+ # Allow transitions from token to blank only
+ graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+ else:
+ # Allow transition from token to blank and all other tokens
+ graph.add_arc(i + 1, 0, gtn.epsilon)
+
+ else:
+ # allow transitions to blank and all other tokens except the same token
+ graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
+ for j in range(ntoks):
+ if i != j:
+ graph.add_arc(i + 1, j + 1, j, j)
+
+ return graph
+
+
+class TransducerLossFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ inputs,
+ targets,
+ tokens,
+ lexicon,
+ transition_params=None,
+ transitions=None,
+ reduction="none",
+ ) -> Tensor:
+ B, T, C = inputs.shape
+
+ losses = [None] * B
+ emissions_graphs = [None] * B
+
+ if transitions is not None:
+ if transition_params is None:
+ raise ValueError("Specified transitions, but not transition params.")
+
+ cpu_data = transition_params.cpu().contiguous()
+ transitions.set_weights(cpu_data.data_ptr())
+ transitions.calc_grad = transition_params.requires_grad
+ transitions.zero_grad()
+
+ def process(b: int) -> None:
+ # Create emission graph:
+ emissions = gtn.linear_graph(T, C, inputs.requires_grad)
+ cpu_data = inputs[b].cpu().contiguous()
+ emissions.set_weights(cpu_data.data_ptr())
+ target = make_chain_graph(targets[b])
+ target.arc_sort(True)
+
+ # Create token tot grapheme decomposition graph
+ tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
+ tokens_target.arc_sort()
+
+ # Create alignment graph:
+ aligments = gtn.project_input(
+ gtn.remove(gtn.compose(tokens, tokens_target))
+ )
+ aligments.arc_sort()
+
+ # Add transitions scores:
+ if transitions is not None:
+ aligments = gtn.intersect(transitions, aligments)
+ aligments.arc_sort()
+
+ loss = gtn.forward_score(gtn.intersect(emissions, aligments))
+
+ # Normalize if needed:
+ if transitions is not None:
+ norm = gtn.forward_score(gtn.intersect(emissions, transitions))
+ loss = gtn.subtract(loss, norm)
+
+ losses[b] = gtn.negate(loss)
+
+ # Save for backward:
+ if emissions.calc_grad:
+ emissions_graphs[b] = emissions
+
+ gtn.parallel_for(process, range(B))
+
+ ctx.graphs = (losses, emissions_graphs, transitions)
+ ctx.input_shape = inputs.shape
+
+ # Optionally reduce by target length
+ if reduction == "mean":
+ scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
+ else:
+ scales = [1.0] * B
+
+ ctx.scales = scales
+
+ loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
+ return torch.mean(loss.to(inputs.device))
+
+ @staticmethod
+ def backward(ctx, grad_output) -> Tuple:
+ losses, emissions_graphs, transitions = ctx.graphs
+ scales = ctx.scales
+
+ B, T, C = ctx.input_shape
+ calc_emissions = ctx.needs_input_grad[0]
+ input_grad = torch.empty((B, T, C)) if calc_emissions else None
+
+ def process(b: int) -> None:
+ scale = make_scalar_graph(scales[b])
+ gtn.backward(losses[b], scale)
+ emissions = emissions_graphs[b]
+ if calc_emissions:
+ grad = emissions.grad().weights_to_numpy()
+ input_grad[b] = torch.tensor(grad).view(1, T, C)
+
+ gtn.parallel_for(process, range(B))
+
+ if calc_emissions:
+ input_grad = input_grad.to(grad_output.device)
+ input_grad *= grad_output / B
+
+ if ctx.needs_input_grad[4]:
+ grad = transitions.grad().weights_to_numpy()
+ transition_grad = torch.tensor(grad).to(grad_output.device)
+ transition_grad *= grad_output / B
+ else:
+ transition_grad = None
+
+ return (
+ input_grad,
+ None, # target
+ None, # tokens
+ None, # lexicon
+ transition_grad, # transition params
+ None, # transitions graph
+ None,
+ )
+
+
+TransducerLoss = TransducerLossFunction.apply
+
+
+class Transducer(nn.Module):
+ def __init__(
+ self,
+ tokens: List,
+ graphemes_to_idx: Dict,
+ ngram: int = 0,
+ transitions: str = None,
+ blank: str = "none",
+ allow_repeats: bool = True,
+ reduction: str = "none",
+ ) -> None:
+ """A generic transducer loss function.
+
+ Args:
+ tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
+ representing the output tokens of the model (e.g. letters,
+ word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
+ could be a list of sub-word tokens.
+ graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
+ "a", "b", ..) to their corresponding integer index.
+ ngram (int) : Order of the token-level transition model. If `ngram=0`
+ then no transition model is used.
+ blank (string) : Specifies the usage of blank token
+ 'none' - do not use blank token
+ 'optional' - allow an optional blank inbetween tokens
+ 'forced' - force a blank inbetween tokens (also referred to as garbage token)
+ allow_repeats (boolean) : If false, then we don't allow paths with
+ consecutive tokens in the alignment graph. This keeps the graph
+ unambiguous in the sense that the same input cannot transduce to
+ different outputs.
+ """
+ super().__init__()
+ if blank not in ["optional", "forced", "none"]:
+ raise ValueError(
+ "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
+ )
+ self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
+ self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
+ self.ngram = ngram
+ if ngram > 0 and transitions is not None:
+ raise ValueError("Only one of ngram and transitions may be specified")
+
+ if ngram > 0:
+ transitions = make_transitions_graph(
+ ngram, len(tokens) + int(blank != "none"), True
+ )
+
+ if transitions is not None:
+ self.transitions = transitions
+ self.transitions.arc_sort()
+ self.transitions_params = nn.Parameter(
+ torch.zeros(self.transitions.num_arcs())
+ )
+ else:
+ self.transitions = None
+ self.transitions_params = None
+ self.reduction = reduction
+
+ def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
+ TransducerLoss(
+ inputs,
+ targets,
+ self.tokens,
+ self.lexicon,
+ self.transitions_params,
+ self.transitions,
+ self.reduction,
+ )
+
+ def viterbi(self, outputs: Tensor) -> List[Tensor]:
+ B, T, C = outputs.shape
+
+ if self.transitions is not None:
+ cpu_data = self.transition_params.cpu().contiguous()
+ self.transitions.set_weights(cpu_data.data_ptr())
+ self.transitions.calc_grad = False
+
+ self.tokens.arc_sort()
+
+ paths = [None] * B
+
+ def process(b: int) -> None:
+ emissions = gtn.linear_graph(T, C, False)
+ cpu_data = outputs[b].cpu().contiguous()
+ emissions.set_weights(cpu_data.data_ptr())
+
+ if self.transitions is not None:
+ full_graph = gtn.intersect(emissions, self.transitions)
+ else:
+ full_graph = emissions
+
+ # Find the best path and remove back-off arcs:
+ path = gtn.remove(gtn.viterbi_path(full_graph))
+
+ # Left compose the viterbi path with the "aligment to token"
+ # transducer to get the outputs:
+ path = gtn.compose(path, self.tokens)
+
+ # When there are ambiguous paths (allow_repeats is true), we take
+ # the shortest:
+ path = gtn.viterbi_path(path)
+ path = gtn.remove(gtn.project_output(path))
+ paths[b] = path.labels_to_list()
+
+ gtn.parallel_for(process, range(B))
+ predictions = [torch.IntTensor(path) for path in paths]
+ return predictions
+
+
+def load_transducer_loss(
+ num_features: int,
+ ngram: int,
+ tokens: str,
+ lexicon: str,
+ transitions: str,
+ blank: str,
+ allow_repeats: bool,
+ prepend_wordsep: bool = False,
+ use_words: bool = False,
+ data_dir: Optional[Union[str, Path]] = None,
+ reduction: str = "mean",
+) -> Tuple[Transducer, int]:
+ if data_dir is None:
+ data_dir = (
+ Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
+ )
+ logger.debug(f"Using data dir: {data_dir}")
+ if not data_dir.exists():
+ raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
+ else:
+ data_dir = Path(data_dir)
+ processed_path = (
+ Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
+ )
+ tokens_path = processed_path / tokens
+ lexicon_path = processed_path / lexicon
+
+ if transitions is not None:
+ transitions = gtn.load(str(processed_path / transitions))
+
+ preprocessor = Preprocessor(
+ data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
+ )
+
+ num_tokens = preprocessor.num_tokens
+
+ criterion = Transducer(
+ preprocessor.tokens,
+ preprocessor.graphemes_to_index,
+ ngram=ngram,
+ transitions=transitions,
+ blank=blank,
+ allow_repeats=allow_repeats,
+ reduction=reduction,
+ )
+
+ return criterion, num_tokens + int(blank != "none")