summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/transducer
diff options
context:
space:
mode:
Diffstat (limited to 'src/text_recognizer/networks/transducer')
-rw-r--r--src/text_recognizer/networks/transducer/__init__.py3
-rw-r--r--src/text_recognizer/networks/transducer/tds_conv.py208
-rw-r--r--src/text_recognizer/networks/transducer/test.py60
-rw-r--r--src/text_recognizer/networks/transducer/transducer.py410
4 files changed, 0 insertions, 681 deletions
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/src/text_recognizer/networks/transducer/__init__.py
deleted file mode 100644
index 8c19a01..0000000
--- a/src/text_recognizer/networks/transducer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Transducer modules."""
-from .tds_conv import TDS2d
-from .transducer import load_transducer_loss, Transducer
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/src/text_recognizer/networks/transducer/tds_conv.py
deleted file mode 100644
index 5fb8ba9..0000000
--- a/src/text_recognizer/networks/transducer/tds_conv.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""Time-Depth Separable Convolutions.
-
-References:
- https://arxiv.org/abs/1904.02619
- https://arxiv.org/pdf/2010.01003.pdf
-
-Code stolen from:
- https://github.com/facebookresearch/gtn_applications
-
-
-"""
-from typing import List, Tuple
-
-from einops import rearrange
-import gtn
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class TDSBlock2d(nn.Module):
- """Internal block of a 2D TDSC network."""
-
- def __init__(
- self,
- in_channels: int,
- img_depth: int,
- kernel_size: Tuple[int],
- dropout_rate: float,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.img_depth = img_depth
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
- self.fc_dim = in_channels * img_depth
-
- # Network placeholders.
- self.conv = None
- self.mlp = None
- self.instance_norm = None
-
- self._build_block()
-
- def _build_block(self) -> None:
- # Convolutional block.
- self.conv = nn.Sequential(
- nn.Conv3d(
- in_channels=self.in_channels,
- out_channels=self.in_channels,
- kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
- padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- )
-
- # MLP block.
- self.mlp = nn.Sequential(
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.Dropout(self.dropout_rate),
- )
-
- # Instance norm.
- self.instance_norm = nn.ModuleList(
- [
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- ]
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, CD, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- B, CD, H, W = x.shape
- C, D = self.in_channels, self.img_depth
- residual = x
- x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
- x = self.conv(x)
- x = rearrange(x, "b c d h w -> b (c d) h w")
- x += residual
-
- x = self.instance_norm[0](x)
-
- x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
- x + self.instance_norm[1](x)
-
- # Output shape: [B, CD, H, W]
- return x
-
-
-class TDS2d(nn.Module):
- """TDS Netowrk.
-
- Structure is the following:
- Downsample layer -> TDS2d group -> ... -> Linear output layer
-
-
- """
-
- def __init__(
- self,
- input_dim: int,
- output_dim: int,
- depth: int,
- tds_groups: Tuple[int],
- kernel_size: Tuple[int],
- dropout_rate: float,
- in_channels: int = 1,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.depth = depth
- self.tds_groups = tds_groups
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
-
- self.tds = None
- self.fc = None
-
- self._build_network()
-
- def _build_network(self) -> None:
- in_channels = self.in_channels
- modules = []
- stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
- if self.input_dim % stride_h:
- raise RuntimeError(
- f"Image height not divisible by total stride {stride_h}."
- )
-
- for tds_group in self.tds_groups:
- # Add downsample layer.
- out_channels = self.depth * tds_group["channels"]
- modules.extend(
- [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=self.kernel_size,
- padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- stride=tds_group["stride"],
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.InstanceNorm2d(out_channels, affine=True),
- ]
- )
-
- for _ in range(tds_group["num_blocks"]):
- modules.append(
- TDSBlock2d(
- tds_group["channels"],
- self.depth,
- self.kernel_size,
- self.dropout_rate,
- )
- )
-
- in_channels = out_channels
-
- self.tds = nn.Sequential(*modules)
- self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- if len(x.shape) == 4:
- x = x.squeeze(1) # Squeeze the channel dim away.
-
- B, H, W = x.shape
- x = rearrange(
- x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
- )
- x = self.tds(x)
-
- # x shape: [B, C, H, W]
- x = rearrange(x, "b c h w -> b w (c h)")
-
- return self.fc(x)
diff --git a/src/text_recognizer/networks/transducer/test.py b/src/text_recognizer/networks/transducer/test.py
deleted file mode 100644
index cadcecc..0000000
--- a/src/text_recognizer/networks/transducer/test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch
-from torch import nn
-
-from text_recognizer.networks.transducer import load_transducer_loss, Transducer
-import unittest
-
-
-class TestTransducer(unittest.TestCase):
- def test_viterbi(self):
- T = 5
- N = 4
- B = 2
-
- # fmt: off
- emissions1 = torch.tensor((
- 0, 4, 0, 1,
- 0, 2, 1, 1,
- 0, 0, 0, 2,
- 0, 0, 0, 2,
- 8, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- emissions2 = torch.tensor((
- 0, 2, 1, 7,
- 0, 2, 9, 1,
- 0, 0, 0, 2,
- 0, 0, 5, 2,
- 1, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- # fmt: on
-
- # Test without blank:
- labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
- transducer = Transducer(
- tokens=["a", "b", "c", "d"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
- blank="none",
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
- # Test with blank without repeats:
- labels = [[1, 0], [2, 2]]
- transducer = Transducer(
- tokens=["a", "b", "c"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2},
- blank="optional",
- allow_repeats=False,
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/src/text_recognizer/networks/transducer/transducer.py b/src/text_recognizer/networks/transducer/transducer.py
deleted file mode 100644
index d7e3d08..0000000
--- a/src/text_recognizer/networks/transducer/transducer.py
+++ /dev/null
@@ -1,410 +0,0 @@
-"""Transducer and the transducer loss function.py
-
-Stolen from:
- https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
-
-"""
-from pathlib import Path
-import itertools
-from typing import Dict, List, Optional, Union, Tuple
-
-from loguru import logger
-import gtn
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-
-
-def make_scalar_graph(weight) -> gtn.Graph:
- scalar = gtn.Graph()
- scalar.add_node(True)
- scalar.add_node(False, True)
- scalar.add_arc(0, 1, 0, 0, weight)
- return scalar
-
-
-def make_chain_graph(sequence) -> gtn.Graph:
- graph = gtn.Graph(False)
- graph.add_node(True)
- for i, s in enumerate(sequence):
- graph.add_node(False, i == (len(sequence) - 1))
- graph.add_arc(i, i + 1, s)
- return graph
-
-
-def make_transitions_graph(
- ngram: int, num_tokens: int, calc_grad: bool = False
-) -> gtn.Graph:
- transitions = gtn.Graph(calc_grad)
- transitions.add_node(True, ngram == 1)
-
- state_map = {(): 0}
-
- # First build transitions which include <s>:
- for n in range(1, ngram):
- for state in itertools.product(range(num_tokens), repeat=n):
- in_idx = state_map[state[:-1]]
- out_idx = transitions.add_node(False, ngram == 1)
- state_map[state] = out_idx
- transitions.add_arc(in_idx, out_idx, state[-1])
-
- for state in itertools.product(range(num_tokens), repeat=ngram):
- state_idx = state_map[state[:-1]]
- new_state_idx = state_map[state[1:]]
- # p(state[-1] | state[:-1])
- transitions.add_arc(state_idx, new_state_idx, state[-1])
-
- if ngram > 1:
- # Build transitions which include </s>:
- end_idx = transitions.add_node(False, True)
- for in_idx in range(end_idx):
- transitions.add_arc(in_idx, end_idx, gtn.epsilon)
-
- return transitions
-
-
-def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
- """Constructs a graph which transduces letters to word pieces."""
- graph = gtn.Graph(False)
- graph.add_node(True, True)
- for i, wp in enumerate(word_pieces):
- prev = 0
- for l in wp[:-1]:
- n = graph.add_node()
- graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
- prev = n
- graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
- graph.arc_sort()
- return graph
-
-
-def make_token_graph(
- token_list: List, blank: str = "none", allow_repeats: bool = True
-) -> gtn.Graph:
- """Constructs a graph with all the individual token transition models."""
- if not allow_repeats and blank != "optional":
- raise ValueError("Must use blank='optional' if disallowing repeats.")
-
- ntoks = len(token_list)
- graph = gtn.Graph(False)
-
- # Creating nodes
- graph.add_node(True, True)
- for i in range(ntoks):
- # We can consume one or more consecutive word
- # pieces for each emission:
- # E.g. [ab, ab, ab] transduces to [ab]
- graph.add_node(False, blank != "forced")
-
- if blank != "none":
- graph.add_node()
-
- # Creating arcs
- if blank != "none":
- # Blank index is assumed to be last (ntoks)
- graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
- graph.add_arc(ntoks + 1, 0, gtn.epsilon)
-
- for i in range(ntoks):
- graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
- graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
-
- if allow_repeats:
- if blank == "forced":
- # Allow transitions from token to blank only
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- else:
- # Allow transition from token to blank and all other tokens
- graph.add_arc(i + 1, 0, gtn.epsilon)
-
- else:
- # allow transitions to blank and all other tokens except the same token
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- for j in range(ntoks):
- if i != j:
- graph.add_arc(i + 1, j + 1, j, j)
-
- return graph
-
-
-class TransducerLossFunction(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- inputs,
- targets,
- tokens,
- lexicon,
- transition_params=None,
- transitions=None,
- reduction="none",
- ) -> Tensor:
- B, T, C = inputs.shape
-
- losses = [None] * B
- emissions_graphs = [None] * B
-
- if transitions is not None:
- if transition_params is None:
- raise ValueError("Specified transitions, but not transition params.")
-
- cpu_data = transition_params.cpu().contiguous()
- transitions.set_weights(cpu_data.data_ptr())
- transitions.calc_grad = transition_params.requires_grad
- transitions.zero_grad()
-
- def process(b: int) -> None:
- # Create emission graph:
- emissions = gtn.linear_graph(T, C, inputs.requires_grad)
- cpu_data = inputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
- target = make_chain_graph(targets[b])
- target.arc_sort(True)
-
- # Create token tot grapheme decomposition graph
- tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
- tokens_target.arc_sort()
-
- # Create alignment graph:
- aligments = gtn.project_input(
- gtn.remove(gtn.compose(tokens, tokens_target))
- )
- aligments.arc_sort()
-
- # Add transitions scores:
- if transitions is not None:
- aligments = gtn.intersect(transitions, aligments)
- aligments.arc_sort()
-
- loss = gtn.forward_score(gtn.intersect(emissions, aligments))
-
- # Normalize if needed:
- if transitions is not None:
- norm = gtn.forward_score(gtn.intersect(emissions, transitions))
- loss = gtn.subtract(loss, norm)
-
- losses[b] = gtn.negate(loss)
-
- # Save for backward:
- if emissions.calc_grad:
- emissions_graphs[b] = emissions
-
- gtn.parallel_for(process, range(B))
-
- ctx.graphs = (losses, emissions_graphs, transitions)
- ctx.input_shape = inputs.shape
-
- # Optionally reduce by target length
- if reduction == "mean":
- scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
- else:
- scales = [1.0] * B
-
- ctx.scales = scales
-
- loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
- return torch.mean(loss.to(inputs.device))
-
- @staticmethod
- def backward(ctx, grad_output) -> Tuple:
- losses, emissions_graphs, transitions = ctx.graphs
- scales = ctx.scales
-
- B, T, C = ctx.input_shape
- calc_emissions = ctx.needs_input_grad[0]
- input_grad = torch.empty((B, T, C)) if calc_emissions else None
-
- def process(b: int) -> None:
- scale = make_scalar_graph(scales[b])
- gtn.backward(losses[b], scale)
- emissions = emissions_graphs[b]
- if calc_emissions:
- grad = emissions.grad().weights_to_numpy()
- input_grad[b] = torch.tensor(grad).view(1, T, C)
-
- gtn.parallel_for(process, range(B))
-
- if calc_emissions:
- input_grad = input_grad.to(grad_output.device)
- input_grad *= grad_output / B
-
- if ctx.needs_input_grad[4]:
- grad = transitions.grad().weights_to_numpy()
- transition_grad = torch.tensor(grad).to(grad_output.device)
- transition_grad *= grad_output / B
- else:
- transition_grad = None
-
- return (
- input_grad,
- None, # target
- None, # tokens
- None, # lexicon
- transition_grad, # transition params
- None, # transitions graph
- None,
- )
-
-
-TransducerLoss = TransducerLossFunction.apply
-
-
-class Transducer(nn.Module):
- def __init__(
- self,
- tokens: List,
- graphemes_to_idx: Dict,
- ngram: int = 0,
- transitions: str = None,
- blank: str = "none",
- allow_repeats: bool = True,
- reduction: str = "none",
- ) -> None:
- """A generic transducer loss function.
-
- Args:
- tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
- representing the output tokens of the model (e.g. letters,
- word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
- could be a list of sub-word tokens.
- graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
- "a", "b", ..) to their corresponding integer index.
- ngram (int) : Order of the token-level transition model. If `ngram=0`
- then no transition model is used.
- blank (string) : Specifies the usage of blank token
- 'none' - do not use blank token
- 'optional' - allow an optional blank inbetween tokens
- 'forced' - force a blank inbetween tokens (also referred to as garbage token)
- allow_repeats (boolean) : If false, then we don't allow paths with
- consecutive tokens in the alignment graph. This keeps the graph
- unambiguous in the sense that the same input cannot transduce to
- different outputs.
- """
- super().__init__()
- if blank not in ["optional", "forced", "none"]:
- raise ValueError(
- "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
- )
- self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
- self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
- self.ngram = ngram
- if ngram > 0 and transitions is not None:
- raise ValueError("Only one of ngram and transitions may be specified")
-
- if ngram > 0:
- transitions = make_transitions_graph(
- ngram, len(tokens) + int(blank != "none"), True
- )
-
- if transitions is not None:
- self.transitions = transitions
- self.transitions.arc_sort()
- self.transitions_params = nn.Parameter(
- torch.zeros(self.transitions.num_arcs())
- )
- else:
- self.transitions = None
- self.transitions_params = None
- self.reduction = reduction
-
- def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
- TransducerLoss(
- inputs,
- targets,
- self.tokens,
- self.lexicon,
- self.transitions_params,
- self.transitions,
- self.reduction,
- )
-
- def viterbi(self, outputs: Tensor) -> List[Tensor]:
- B, T, C = outputs.shape
-
- if self.transitions is not None:
- cpu_data = self.transition_params.cpu().contiguous()
- self.transitions.set_weights(cpu_data.data_ptr())
- self.transitions.calc_grad = False
-
- self.tokens.arc_sort()
-
- paths = [None] * B
-
- def process(b: int) -> None:
- emissions = gtn.linear_graph(T, C, False)
- cpu_data = outputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
-
- if self.transitions is not None:
- full_graph = gtn.intersect(emissions, self.transitions)
- else:
- full_graph = emissions
-
- # Find the best path and remove back-off arcs:
- path = gtn.remove(gtn.viterbi_path(full_graph))
-
- # Left compose the viterbi path with the "aligment to token"
- # transducer to get the outputs:
- path = gtn.compose(path, self.tokens)
-
- # When there are ambiguous paths (allow_repeats is true), we take
- # the shortest:
- path = gtn.viterbi_path(path)
- path = gtn.remove(gtn.project_output(path))
- paths[b] = path.labels_to_list()
-
- gtn.parallel_for(process, range(B))
- predictions = [torch.IntTensor(path) for path in paths]
- return predictions
-
-
-def load_transducer_loss(
- num_features: int,
- ngram: int,
- tokens: str,
- lexicon: str,
- transitions: str,
- blank: str,
- allow_repeats: bool,
- prepend_wordsep: bool = False,
- use_words: bool = False,
- data_dir: Optional[Union[str, Path]] = None,
- reduction: str = "mean",
-) -> Tuple[Transducer, int]:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- if transitions is not None:
- transitions = gtn.load(str(processed_path / transitions))
-
- preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
- )
-
- num_tokens = preprocessor.num_tokens
-
- criterion = Transducer(
- preprocessor.tokens,
- preprocessor.graphemes_to_index,
- ngram=ngram,
- transitions=transitions,
- blank=blank,
- allow_repeats=allow_repeats,
- reduction=reduction,
- )
-
- return criterion, num_tokens + int(blank != "none")