summaryrefslogtreecommitdiff
path: root/text_recognizer/networks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-06 23:19:35 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-06-06 23:19:35 +0200
commit01d6e5fc066969283df99c759609df441151e9c5 (patch)
treeecd1459e142356d0c7f50a61307b760aca813248 /text_recognizer/networks
parentf4688482b4898c0b342d6ae59839dc27fbf856c6 (diff)
Working on fixing decoder transformer
Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--text_recognizer/networks/cnn_transformer.py182
-rw-r--r--text_recognizer/networks/transducer/__init__.py3
-rw-r--r--text_recognizer/networks/transducer/tds_conv.py208
-rw-r--r--text_recognizer/networks/transducer/test.py60
-rw-r--r--text_recognizer/networks/transducer/transducer.py410
-rw-r--r--text_recognizer/networks/transformer/__init__.py2
-rw-r--r--text_recognizer/networks/transformer/layers.py5
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py1
-rw-r--r--text_recognizer/networks/transformer/transformer.py7
-rw-r--r--text_recognizer/networks/util.py39
10 files changed, 9 insertions, 908 deletions
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
deleted file mode 100644
index 80798e1..0000000
--- a/text_recognizer/networks/cnn_transformer.py
+++ /dev/null
@@ -1,182 +0,0 @@
-# """A Transformer with a cnn backbone.
-#
-# The network encodes a image with a convolutional backbone to a latent representation,
-# i.e. feature maps. A 2d positional encoding is applied to the feature maps for
-# spatial information. The resulting feature are then set to a transformer decoder
-# together with the target tokens.
-#
-# TODO: Local attention for lower layer in attention.
-#
-# """
-# import importlib
-# import math
-# from typing import Dict, Optional, Union, Sequence, Type
-#
-# from einops import rearrange
-# from omegaconf import DictConfig, OmegaConf
-# import torch
-# from torch import nn
-# from torch import Tensor
-#
-# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
-# from text_recognizer.networks.transformer import (
-# Decoder,
-# DecoderLayer,
-# PositionalEncoding,
-# PositionalEncoding2D,
-# target_padding_mask,
-# )
-#
-# NUM_WORD_PIECES = 1000
-#
-#
-# class CNNTransformer(nn.Module):
-# def __init__(
-# self,
-# input_dim: Sequence[int],
-# output_dims: Sequence[int],
-# encoder: Union[DictConfig, Dict],
-# vocab_size: Optional[int] = None,
-# num_decoder_layers: int = 4,
-# hidden_dim: int = 256,
-# num_heads: int = 4,
-# expansion_dim: int = 1024,
-# dropout_rate: float = 0.1,
-# transformer_activation: str = "glu",
-# *args,
-# **kwargs,
-# ) -> None:
-# super().__init__()
-# self.vocab_size = (
-# NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
-# )
-# self.pad_index = 3 # TODO: fix me
-# self.hidden_dim = hidden_dim
-# self.max_output_length = output_dims[0]
-#
-# # Image backbone
-# self.encoder = self._configure_encoder(encoder)
-# self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1)
-# self.feature_map_encoding = PositionalEncoding2D(
-# hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2]
-# )
-#
-# # Target token embedding
-# self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-# self.trg_position_encoding = PositionalEncoding(
-# hidden_dim, dropout_rate, max_len=output_dims[0]
-# )
-#
-# # Transformer decoder
-# self.decoder = Decoder(
-# decoder_layer=DecoderLayer(
-# hidden_dim=hidden_dim,
-# num_heads=num_heads,
-# expansion_dim=expansion_dim,
-# dropout_rate=dropout_rate,
-# activation=transformer_activation,
-# ),
-# num_layers=num_decoder_layers,
-# norm=nn.LayerNorm(hidden_dim),
-# )
-#
-# # Classification head
-# self.head = nn.Linear(hidden_dim, self.vocab_size)
-#
-# # Initialize weights
-# self._init_weights()
-#
-# def _init_weights(self) -> None:
-# """Initialize network weights."""
-# self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
-# self.head.bias.data.zero_()
-# self.head.weight.data.uniform_(-0.1, 0.1)
-#
-# nn.init.kaiming_normal_(
-# self.encoder_proj.weight.data,
-# a=0,
-# mode="fan_out",
-# nonlinearity="relu",
-# )
-# if self.encoder_proj.bias is not None:
-# _, fan_out = nn.init._calculate_fan_in_and_fan_out(
-# self.encoder_proj.weight.data
-# )
-# bound = 1 / math.sqrt(fan_out)
-# nn.init.normal_(self.encoder_proj.bias, -bound, bound)
-#
-# @staticmethod
-# def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]:
-# encoder = OmegaConf.create(encoder)
-# args = encoder.args or {}
-# network_module = importlib.import_module("text_recognizer.networks")
-# encoder_class = getattr(network_module, encoder.type)
-# return encoder_class(**args)
-#
-# def encode(self, image: Tensor) -> Tensor:
-# """Extracts image features with backbone.
-#
-# Args:
-# image (Tensor): Image(s) of handwritten text.
-#
-# Retuns:
-# Tensor: Image features.
-#
-# Shapes:
-# - image: :math: `(B, C, H, W)`
-# - latent: :math: `(B, T, C)`
-#
-# """
-# # Extract image features.
-# image_features = self.encoder(image)
-# image_features = self.encoder_proj(image_features)
-#
-# # Add 2d encoding to the feature maps.
-# image_features = self.feature_map_encoding(image_features)
-#
-# # Collapse features maps height and width.
-# image_features = rearrange(image_features, "b c h w -> b (h w) c")
-# return image_features
-#
-# def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
-# """Decodes image features with transformer decoder."""
-# trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
-# trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
-# trg = rearrange(trg, "b t d -> t b d")
-# trg = self.trg_position_encoding(trg)
-# trg = rearrange(trg, "t b d -> b t d")
-# out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
-# logits = self.head(out)
-# return logits
-#
-# def forward(self, image: Tensor, trg: Tensor) -> Tensor:
-# image_features = self.encode(image)
-# output = self.decode(image_features, trg)
-# output = rearrange(output, "b t c -> b c t")
-# return output
-#
-# def predict(self, image: Tensor) -> Tensor:
-# """Transcribes text in image(s)."""
-# bsz = image.shape[0]
-# image_features = self.encode(image)
-#
-# output_tokens = (
-# (torch.ones((bsz, self.max_output_length)) * self.pad_index)
-# .type_as(image)
-# .long()
-# )
-# output_tokens[:, 0] = self.start_index
-# for i in range(1, self.max_output_length):
-# trg = output_tokens[:, :i]
-# output = self.decode(image_features, trg)
-# output = torch.argmax(output, dim=-1)
-# output_tokens[:, i] = output[-1:]
-#
-# # Set all tokens after end token to be padding.
-# for i in range(1, self.max_output_length):
-# indices = output_tokens[:, i - 1] == self.end_index | (
-# output_tokens[:, i - 1] == self.pad_index
-# )
-# output_tokens[indices, i] = self.pad_index
-#
-# return output_tokens
diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py
deleted file mode 100644
index 8c19a01..0000000
--- a/text_recognizer/networks/transducer/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-"""Transducer modules."""
-from .tds_conv import TDS2d
-from .transducer import load_transducer_loss, Transducer
diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py
deleted file mode 100644
index 5fb8ba9..0000000
--- a/text_recognizer/networks/transducer/tds_conv.py
+++ /dev/null
@@ -1,208 +0,0 @@
-"""Time-Depth Separable Convolutions.
-
-References:
- https://arxiv.org/abs/1904.02619
- https://arxiv.org/pdf/2010.01003.pdf
-
-Code stolen from:
- https://github.com/facebookresearch/gtn_applications
-
-
-"""
-from typing import List, Tuple
-
-from einops import rearrange
-import gtn
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class TDSBlock2d(nn.Module):
- """Internal block of a 2D TDSC network."""
-
- def __init__(
- self,
- in_channels: int,
- img_depth: int,
- kernel_size: Tuple[int],
- dropout_rate: float,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.img_depth = img_depth
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
- self.fc_dim = in_channels * img_depth
-
- # Network placeholders.
- self.conv = None
- self.mlp = None
- self.instance_norm = None
-
- self._build_block()
-
- def _build_block(self) -> None:
- # Convolutional block.
- self.conv = nn.Sequential(
- nn.Conv3d(
- in_channels=self.in_channels,
- out_channels=self.in_channels,
- kernel_size=(1, self.kernel_size[0], self.kernel_size[1]),
- padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- )
-
- # MLP block.
- self.mlp = nn.Sequential(
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.Linear(self.fc_dim, self.fc_dim),
- nn.Dropout(self.dropout_rate),
- )
-
- # Instance norm.
- self.instance_norm = nn.ModuleList(
- [
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- nn.InstanceNorm2d(self.fc_dim, affine=True),
- ]
- )
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, CD, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- B, CD, H, W = x.shape
- C, D = self.in_channels, self.img_depth
- residual = x
- x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D)
- x = self.conv(x)
- x = rearrange(x, "b c d h w -> b (c d) h w")
- x += residual
-
- x = self.instance_norm[0](x)
-
- x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x
- x + self.instance_norm[1](x)
-
- # Output shape: [B, CD, H, W]
- return x
-
-
-class TDS2d(nn.Module):
- """TDS Netowrk.
-
- Structure is the following:
- Downsample layer -> TDS2d group -> ... -> Linear output layer
-
-
- """
-
- def __init__(
- self,
- input_dim: int,
- output_dim: int,
- depth: int,
- tds_groups: Tuple[int],
- kernel_size: Tuple[int],
- dropout_rate: float,
- in_channels: int = 1,
- ) -> None:
- super().__init__()
-
- self.in_channels = in_channels
- self.input_dim = input_dim
- self.output_dim = output_dim
- self.depth = depth
- self.tds_groups = tds_groups
- self.kernel_size = kernel_size
- self.dropout_rate = dropout_rate
-
- self.tds = None
- self.fc = None
-
- self._build_network()
-
- def _build_network(self) -> None:
- in_channels = self.in_channels
- modules = []
- stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups])
- if self.input_dim % stride_h:
- raise RuntimeError(
- f"Image height not divisible by total stride {stride_h}."
- )
-
- for tds_group in self.tds_groups:
- # Add downsample layer.
- out_channels = self.depth * tds_group["channels"]
- modules.extend(
- [
- nn.Conv2d(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=self.kernel_size,
- padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2),
- stride=tds_group["stride"],
- ),
- nn.ReLU(inplace=True),
- nn.Dropout(self.dropout_rate),
- nn.InstanceNorm2d(out_channels, affine=True),
- ]
- )
-
- for _ in range(tds_group["num_blocks"]):
- modules.append(
- TDSBlock2d(
- tds_group["channels"],
- self.depth,
- self.kernel_size,
- self.dropout_rate,
- )
- )
-
- in_channels = out_channels
-
- self.tds = nn.Sequential(*modules)
- self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim)
-
- def forward(self, x: Tensor) -> Tensor:
- """Forward pass.
-
- Args:
- x (Tensor): Input tensor.
-
- Shape:
- - x: :math: `(B, H, W)`
-
- Returns:
- Tensor: Output tensor.
-
- """
- if len(x.shape) == 4:
- x = x.squeeze(1) # Squeeze the channel dim away.
-
- B, H, W = x.shape
- x = rearrange(
- x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels
- )
- x = self.tds(x)
-
- # x shape: [B, C, H, W]
- x = rearrange(x, "b c h w -> b w (c h)")
-
- return self.fc(x)
diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py
deleted file mode 100644
index cadcecc..0000000
--- a/text_recognizer/networks/transducer/test.py
+++ /dev/null
@@ -1,60 +0,0 @@
-import torch
-from torch import nn
-
-from text_recognizer.networks.transducer import load_transducer_loss, Transducer
-import unittest
-
-
-class TestTransducer(unittest.TestCase):
- def test_viterbi(self):
- T = 5
- N = 4
- B = 2
-
- # fmt: off
- emissions1 = torch.tensor((
- 0, 4, 0, 1,
- 0, 2, 1, 1,
- 0, 0, 0, 2,
- 0, 0, 0, 2,
- 8, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- emissions2 = torch.tensor((
- 0, 2, 1, 7,
- 0, 2, 9, 1,
- 0, 0, 0, 2,
- 0, 0, 5, 2,
- 1, 0, 0, 2,
- ),
- dtype=torch.float,
- ).view(T, N)
- # fmt: on
-
- # Test without blank:
- labels = [[1, 3, 0], [3, 2, 3, 2, 3]]
- transducer = Transducer(
- tokens=["a", "b", "c", "d"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3},
- blank="none",
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
- # Test with blank without repeats:
- labels = [[1, 0], [2, 2]]
- transducer = Transducer(
- tokens=["a", "b", "c"],
- graphemes_to_idx={"a": 0, "b": 1, "c": 2},
- blank="optional",
- allow_repeats=False,
- )
- emissions = torch.stack([emissions1, emissions2], dim=0)
- predictions = transducer.viterbi(emissions)
- self.assertEqual([p.tolist() for p in predictions], labels)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
deleted file mode 100644
index d7e3d08..0000000
--- a/text_recognizer/networks/transducer/transducer.py
+++ /dev/null
@@ -1,410 +0,0 @@
-"""Transducer and the transducer loss function.py
-
-Stolen from:
- https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py
-
-"""
-from pathlib import Path
-import itertools
-from typing import Dict, List, Optional, Union, Tuple
-
-from loguru import logger
-import gtn
-import torch
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.datasets.iam_preprocessor import Preprocessor
-
-
-def make_scalar_graph(weight) -> gtn.Graph:
- scalar = gtn.Graph()
- scalar.add_node(True)
- scalar.add_node(False, True)
- scalar.add_arc(0, 1, 0, 0, weight)
- return scalar
-
-
-def make_chain_graph(sequence) -> gtn.Graph:
- graph = gtn.Graph(False)
- graph.add_node(True)
- for i, s in enumerate(sequence):
- graph.add_node(False, i == (len(sequence) - 1))
- graph.add_arc(i, i + 1, s)
- return graph
-
-
-def make_transitions_graph(
- ngram: int, num_tokens: int, calc_grad: bool = False
-) -> gtn.Graph:
- transitions = gtn.Graph(calc_grad)
- transitions.add_node(True, ngram == 1)
-
- state_map = {(): 0}
-
- # First build transitions which include <s>:
- for n in range(1, ngram):
- for state in itertools.product(range(num_tokens), repeat=n):
- in_idx = state_map[state[:-1]]
- out_idx = transitions.add_node(False, ngram == 1)
- state_map[state] = out_idx
- transitions.add_arc(in_idx, out_idx, state[-1])
-
- for state in itertools.product(range(num_tokens), repeat=ngram):
- state_idx = state_map[state[:-1]]
- new_state_idx = state_map[state[1:]]
- # p(state[-1] | state[:-1])
- transitions.add_arc(state_idx, new_state_idx, state[-1])
-
- if ngram > 1:
- # Build transitions which include </s>:
- end_idx = transitions.add_node(False, True)
- for in_idx in range(end_idx):
- transitions.add_arc(in_idx, end_idx, gtn.epsilon)
-
- return transitions
-
-
-def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph:
- """Constructs a graph which transduces letters to word pieces."""
- graph = gtn.Graph(False)
- graph.add_node(True, True)
- for i, wp in enumerate(word_pieces):
- prev = 0
- for l in wp[:-1]:
- n = graph.add_node()
- graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon)
- prev = n
- graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i)
- graph.arc_sort()
- return graph
-
-
-def make_token_graph(
- token_list: List, blank: str = "none", allow_repeats: bool = True
-) -> gtn.Graph:
- """Constructs a graph with all the individual token transition models."""
- if not allow_repeats and blank != "optional":
- raise ValueError("Must use blank='optional' if disallowing repeats.")
-
- ntoks = len(token_list)
- graph = gtn.Graph(False)
-
- # Creating nodes
- graph.add_node(True, True)
- for i in range(ntoks):
- # We can consume one or more consecutive word
- # pieces for each emission:
- # E.g. [ab, ab, ab] transduces to [ab]
- graph.add_node(False, blank != "forced")
-
- if blank != "none":
- graph.add_node()
-
- # Creating arcs
- if blank != "none":
- # Blank index is assumed to be last (ntoks)
- graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon)
- graph.add_arc(ntoks + 1, 0, gtn.epsilon)
-
- for i in range(ntoks):
- graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i)
- graph.add_arc(i + 1, i + 1, i, gtn.epsilon)
-
- if allow_repeats:
- if blank == "forced":
- # Allow transitions from token to blank only
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- else:
- # Allow transition from token to blank and all other tokens
- graph.add_arc(i + 1, 0, gtn.epsilon)
-
- else:
- # allow transitions to blank and all other tokens except the same token
- graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon)
- for j in range(ntoks):
- if i != j:
- graph.add_arc(i + 1, j + 1, j, j)
-
- return graph
-
-
-class TransducerLossFunction(torch.autograd.Function):
- @staticmethod
- def forward(
- ctx,
- inputs,
- targets,
- tokens,
- lexicon,
- transition_params=None,
- transitions=None,
- reduction="none",
- ) -> Tensor:
- B, T, C = inputs.shape
-
- losses = [None] * B
- emissions_graphs = [None] * B
-
- if transitions is not None:
- if transition_params is None:
- raise ValueError("Specified transitions, but not transition params.")
-
- cpu_data = transition_params.cpu().contiguous()
- transitions.set_weights(cpu_data.data_ptr())
- transitions.calc_grad = transition_params.requires_grad
- transitions.zero_grad()
-
- def process(b: int) -> None:
- # Create emission graph:
- emissions = gtn.linear_graph(T, C, inputs.requires_grad)
- cpu_data = inputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
- target = make_chain_graph(targets[b])
- target.arc_sort(True)
-
- # Create token tot grapheme decomposition graph
- tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon)))
- tokens_target.arc_sort()
-
- # Create alignment graph:
- aligments = gtn.project_input(
- gtn.remove(gtn.compose(tokens, tokens_target))
- )
- aligments.arc_sort()
-
- # Add transitions scores:
- if transitions is not None:
- aligments = gtn.intersect(transitions, aligments)
- aligments.arc_sort()
-
- loss = gtn.forward_score(gtn.intersect(emissions, aligments))
-
- # Normalize if needed:
- if transitions is not None:
- norm = gtn.forward_score(gtn.intersect(emissions, transitions))
- loss = gtn.subtract(loss, norm)
-
- losses[b] = gtn.negate(loss)
-
- # Save for backward:
- if emissions.calc_grad:
- emissions_graphs[b] = emissions
-
- gtn.parallel_for(process, range(B))
-
- ctx.graphs = (losses, emissions_graphs, transitions)
- ctx.input_shape = inputs.shape
-
- # Optionally reduce by target length
- if reduction == "mean":
- scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets]
- else:
- scales = [1.0] * B
-
- ctx.scales = scales
-
- loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)])
- return torch.mean(loss.to(inputs.device))
-
- @staticmethod
- def backward(ctx, grad_output) -> Tuple:
- losses, emissions_graphs, transitions = ctx.graphs
- scales = ctx.scales
-
- B, T, C = ctx.input_shape
- calc_emissions = ctx.needs_input_grad[0]
- input_grad = torch.empty((B, T, C)) if calc_emissions else None
-
- def process(b: int) -> None:
- scale = make_scalar_graph(scales[b])
- gtn.backward(losses[b], scale)
- emissions = emissions_graphs[b]
- if calc_emissions:
- grad = emissions.grad().weights_to_numpy()
- input_grad[b] = torch.tensor(grad).view(1, T, C)
-
- gtn.parallel_for(process, range(B))
-
- if calc_emissions:
- input_grad = input_grad.to(grad_output.device)
- input_grad *= grad_output / B
-
- if ctx.needs_input_grad[4]:
- grad = transitions.grad().weights_to_numpy()
- transition_grad = torch.tensor(grad).to(grad_output.device)
- transition_grad *= grad_output / B
- else:
- transition_grad = None
-
- return (
- input_grad,
- None, # target
- None, # tokens
- None, # lexicon
- transition_grad, # transition params
- None, # transitions graph
- None,
- )
-
-
-TransducerLoss = TransducerLossFunction.apply
-
-
-class Transducer(nn.Module):
- def __init__(
- self,
- tokens: List,
- graphemes_to_idx: Dict,
- ngram: int = 0,
- transitions: str = None,
- blank: str = "none",
- allow_repeats: bool = True,
- reduction: str = "none",
- ) -> None:
- """A generic transducer loss function.
-
- Args:
- tokens (List) : A list of iterable objects (e.g. strings, tuples, etc)
- representing the output tokens of the model (e.g. letters,
- word-pieces, words). For example ["a", "b", "ab", "ba", "aba"]
- could be a list of sub-word tokens.
- graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g.
- "a", "b", ..) to their corresponding integer index.
- ngram (int) : Order of the token-level transition model. If `ngram=0`
- then no transition model is used.
- blank (string) : Specifies the usage of blank token
- 'none' - do not use blank token
- 'optional' - allow an optional blank inbetween tokens
- 'forced' - force a blank inbetween tokens (also referred to as garbage token)
- allow_repeats (boolean) : If false, then we don't allow paths with
- consecutive tokens in the alignment graph. This keeps the graph
- unambiguous in the sense that the same input cannot transduce to
- different outputs.
- """
- super().__init__()
- if blank not in ["optional", "forced", "none"]:
- raise ValueError(
- "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']"
- )
- self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats)
- self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx)
- self.ngram = ngram
- if ngram > 0 and transitions is not None:
- raise ValueError("Only one of ngram and transitions may be specified")
-
- if ngram > 0:
- transitions = make_transitions_graph(
- ngram, len(tokens) + int(blank != "none"), True
- )
-
- if transitions is not None:
- self.transitions = transitions
- self.transitions.arc_sort()
- self.transitions_params = nn.Parameter(
- torch.zeros(self.transitions.num_arcs())
- )
- else:
- self.transitions = None
- self.transitions_params = None
- self.reduction = reduction
-
- def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss:
- TransducerLoss(
- inputs,
- targets,
- self.tokens,
- self.lexicon,
- self.transitions_params,
- self.transitions,
- self.reduction,
- )
-
- def viterbi(self, outputs: Tensor) -> List[Tensor]:
- B, T, C = outputs.shape
-
- if self.transitions is not None:
- cpu_data = self.transition_params.cpu().contiguous()
- self.transitions.set_weights(cpu_data.data_ptr())
- self.transitions.calc_grad = False
-
- self.tokens.arc_sort()
-
- paths = [None] * B
-
- def process(b: int) -> None:
- emissions = gtn.linear_graph(T, C, False)
- cpu_data = outputs[b].cpu().contiguous()
- emissions.set_weights(cpu_data.data_ptr())
-
- if self.transitions is not None:
- full_graph = gtn.intersect(emissions, self.transitions)
- else:
- full_graph = emissions
-
- # Find the best path and remove back-off arcs:
- path = gtn.remove(gtn.viterbi_path(full_graph))
-
- # Left compose the viterbi path with the "aligment to token"
- # transducer to get the outputs:
- path = gtn.compose(path, self.tokens)
-
- # When there are ambiguous paths (allow_repeats is true), we take
- # the shortest:
- path = gtn.viterbi_path(path)
- path = gtn.remove(gtn.project_output(path))
- paths[b] = path.labels_to_list()
-
- gtn.parallel_for(process, range(B))
- predictions = [torch.IntTensor(path) for path in paths]
- return predictions
-
-
-def load_transducer_loss(
- num_features: int,
- ngram: int,
- tokens: str,
- lexicon: str,
- transitions: str,
- blank: str,
- allow_repeats: bool,
- prepend_wordsep: bool = False,
- use_words: bool = False,
- data_dir: Optional[Union[str, Path]] = None,
- reduction: str = "mean",
-) -> Tuple[Transducer, int]:
- if data_dir is None:
- data_dir = (
- Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb"
- )
- logger.debug(f"Using data dir: {data_dir}")
- if not data_dir.exists():
- raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
- else:
- data_dir = Path(data_dir)
- processed_path = (
- Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines"
- )
- tokens_path = processed_path / tokens
- lexicon_path = processed_path / lexicon
-
- if transitions is not None:
- transitions = gtn.load(str(processed_path / transitions))
-
- preprocessor = Preprocessor(
- data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep,
- )
-
- num_tokens = preprocessor.num_tokens
-
- criterion = Transducer(
- preprocessor.tokens,
- preprocessor.graphemes_to_index,
- ngram=ngram,
- transitions=transitions,
- blank=blank,
- allow_repeats=allow_repeats,
- reduction=reduction,
- )
-
- return criterion, num_tokens + int(blank != "none")
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index a3f3011..d9e63ef 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1 +1,3 @@
"""Transformer modules."""
+from .nystromer.nystromer import Nystromer
+from .vit import ViT
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index b2c703f..a44a525 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,8 +1,6 @@
"""Generates the attention layer architecture."""
from functools import partial
-from typing import Any, Dict, Optional, Type
-
-from click.types import Tuple
+from typing import Any, Dict, Optional, Tuple, Type
from torch import nn, Tensor
@@ -30,6 +28,7 @@ class AttentionLayers(nn.Module):
pre_norm: bool = True,
) -> None:
super().__init__()
+ self.dim = dim
attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs)
norm_fn = partial(norm_fn, dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
index 9466f6e..7140537 100644
--- a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
+++ b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
@@ -1,4 +1,5 @@
"""Absolute positional embedding."""
+import torch
from torch import nn, Tensor
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index 60ab1ce..31088b4 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -19,7 +19,9 @@ class Transformer(nn.Module):
emb_dropout: float = 0.0,
use_pos_emb: bool = True,
) -> None:
+ super().__init__()
dim = attn_layers.dim
+ self.attn_layers = attn_layers
emb_dim = emb_dim if emb_dim is not None else dim
self.max_seq_len = max_seq_len
@@ -32,7 +34,6 @@ class Transformer(nn.Module):
)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
- self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self._init_weights()
@@ -45,12 +46,12 @@ class Transformer(nn.Module):
def forward(
self,
x: Tensor,
- mask: Optional[Tensor],
+ mask: Optional[Tensor] = None,
return_embeddings: bool = False,
**kwargs: Any
) -> Tensor:
b, n, device = *x.shape, x.device
- x += self.token_emb(x)
+ x = self.token_emb(x)
if self.pos_emb is not None:
x += self.pos_emb(x)
x = self.emb_dropout(x)
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 9c6b151..05b10a8 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -22,42 +22,3 @@ def activation_function(activation: str) -> Type[nn.Module]:
]
)
return activation_fns[activation.lower()]
-
-
-# def configure_backbone(backbone: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]:
-# """Loads a backbone network."""
-# network_module = importlib.import_module("text_recognizer.networks")
-# backbone_class = getattr(network_module, backbone.type)
-#
-# if "pretrained" in backbone.args:
-# logger.info("Loading pretrained backbone.")
-# checkpoint_file = Path(__file__).resolve().parents[2] / backbone.args.pop(
-# "pretrained"
-# )
-#
-# # Loading state directory.
-# state_dict = torch.load(checkpoint_file)
-# network_args = state_dict["network_args"]
-# weights = state_dict["model_state"]
-#
-# freeze = False
-# if "freeze" in backbone.args and backbone.args["freeze"] is True:
-# backbone.args.pop("freeze")
-# freeze = True
-#
-# # Initializes the network with trained weights.
-# backbone_ = backbone_(**backbone.args)
-# backbone_.load_state_dict(weights)
-# if freeze:
-# for params in backbone_.parameters():
-# params.requires_grad = False
-# else:
-# backbone_ = getattr(network_module, backbone.type)
-# backbone_ = backbone_(**backbone.args)
-#
-# if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None:
-# backbone = nn.Sequential(
-# *list(backbone.children())[:][: -backbone_args["remove_layers"]]
-# )
-#
-# return backbone