From 01d6e5fc066969283df99c759609df441151e9c5 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 6 Jun 2021 23:19:35 +0200 Subject: Working on fixing decoder transformer --- text_recognizer/networks/cnn_transformer.py | 182 --------- text_recognizer/networks/transducer/__init__.py | 3 - text_recognizer/networks/transducer/tds_conv.py | 208 ----------- text_recognizer/networks/transducer/test.py | 60 --- text_recognizer/networks/transducer/transducer.py | 410 --------------------- text_recognizer/networks/transformer/__init__.py | 2 + text_recognizer/networks/transformer/layers.py | 5 +- .../positional_encodings/absolute_embedding.py | 1 + .../networks/transformer/transformer.py | 7 +- text_recognizer/networks/util.py | 39 -- 10 files changed, 9 insertions(+), 908 deletions(-) delete mode 100644 text_recognizer/networks/cnn_transformer.py delete mode 100644 text_recognizer/networks/transducer/__init__.py delete mode 100644 text_recognizer/networks/transducer/tds_conv.py delete mode 100644 text_recognizer/networks/transducer/test.py delete mode 100644 text_recognizer/networks/transducer/transducer.py (limited to 'text_recognizer/networks') diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py deleted file mode 100644 index 80798e1..0000000 --- a/text_recognizer/networks/cnn_transformer.py +++ /dev/null @@ -1,182 +0,0 @@ -# """A Transformer with a cnn backbone. -# -# The network encodes a image with a convolutional backbone to a latent representation, -# i.e. feature maps. A 2d positional encoding is applied to the feature maps for -# spatial information. The resulting feature are then set to a transformer decoder -# together with the target tokens. -# -# TODO: Local attention for lower layer in attention. -# -# """ -# import importlib -# import math -# from typing import Dict, Optional, Union, Sequence, Type -# -# from einops import rearrange -# from omegaconf import DictConfig, OmegaConf -# import torch -# from torch import nn -# from torch import Tensor -# -# from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS -# from text_recognizer.networks.transformer import ( -# Decoder, -# DecoderLayer, -# PositionalEncoding, -# PositionalEncoding2D, -# target_padding_mask, -# ) -# -# NUM_WORD_PIECES = 1000 -# -# -# class CNNTransformer(nn.Module): -# def __init__( -# self, -# input_dim: Sequence[int], -# output_dims: Sequence[int], -# encoder: Union[DictConfig, Dict], -# vocab_size: Optional[int] = None, -# num_decoder_layers: int = 4, -# hidden_dim: int = 256, -# num_heads: int = 4, -# expansion_dim: int = 1024, -# dropout_rate: float = 0.1, -# transformer_activation: str = "glu", -# *args, -# **kwargs, -# ) -> None: -# super().__init__() -# self.vocab_size = ( -# NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size -# ) -# self.pad_index = 3 # TODO: fix me -# self.hidden_dim = hidden_dim -# self.max_output_length = output_dims[0] -# -# # Image backbone -# self.encoder = self._configure_encoder(encoder) -# self.encoder_proj = nn.Conv2d(256, hidden_dim, kernel_size=1) -# self.feature_map_encoding = PositionalEncoding2D( -# hidden_dim=hidden_dim, max_h=input_dim[1], max_w=input_dim[2] -# ) -# -# # Target token embedding -# self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) -# self.trg_position_encoding = PositionalEncoding( -# hidden_dim, dropout_rate, max_len=output_dims[0] -# ) -# -# # Transformer decoder -# self.decoder = Decoder( -# decoder_layer=DecoderLayer( -# hidden_dim=hidden_dim, -# num_heads=num_heads, -# expansion_dim=expansion_dim, -# dropout_rate=dropout_rate, -# activation=transformer_activation, -# ), -# num_layers=num_decoder_layers, -# norm=nn.LayerNorm(hidden_dim), -# ) -# -# # Classification head -# self.head = nn.Linear(hidden_dim, self.vocab_size) -# -# # Initialize weights -# self._init_weights() -# -# def _init_weights(self) -> None: -# """Initialize network weights.""" -# self.trg_embedding.weight.data.uniform_(-0.1, 0.1) -# self.head.bias.data.zero_() -# self.head.weight.data.uniform_(-0.1, 0.1) -# -# nn.init.kaiming_normal_( -# self.encoder_proj.weight.data, -# a=0, -# mode="fan_out", -# nonlinearity="relu", -# ) -# if self.encoder_proj.bias is not None: -# _, fan_out = nn.init._calculate_fan_in_and_fan_out( -# self.encoder_proj.weight.data -# ) -# bound = 1 / math.sqrt(fan_out) -# nn.init.normal_(self.encoder_proj.bias, -bound, bound) -# -# @staticmethod -# def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: -# encoder = OmegaConf.create(encoder) -# args = encoder.args or {} -# network_module = importlib.import_module("text_recognizer.networks") -# encoder_class = getattr(network_module, encoder.type) -# return encoder_class(**args) -# -# def encode(self, image: Tensor) -> Tensor: -# """Extracts image features with backbone. -# -# Args: -# image (Tensor): Image(s) of handwritten text. -# -# Retuns: -# Tensor: Image features. -# -# Shapes: -# - image: :math: `(B, C, H, W)` -# - latent: :math: `(B, T, C)` -# -# """ -# # Extract image features. -# image_features = self.encoder(image) -# image_features = self.encoder_proj(image_features) -# -# # Add 2d encoding to the feature maps. -# image_features = self.feature_map_encoding(image_features) -# -# # Collapse features maps height and width. -# image_features = rearrange(image_features, "b c h w -> b (h w) c") -# return image_features -# -# def decode(self, memory: Tensor, trg: Tensor) -> Tensor: -# """Decodes image features with transformer decoder.""" -# trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) -# trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) -# trg = rearrange(trg, "b t d -> t b d") -# trg = self.trg_position_encoding(trg) -# trg = rearrange(trg, "t b d -> b t d") -# out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) -# logits = self.head(out) -# return logits -# -# def forward(self, image: Tensor, trg: Tensor) -> Tensor: -# image_features = self.encode(image) -# output = self.decode(image_features, trg) -# output = rearrange(output, "b t c -> b c t") -# return output -# -# def predict(self, image: Tensor) -> Tensor: -# """Transcribes text in image(s).""" -# bsz = image.shape[0] -# image_features = self.encode(image) -# -# output_tokens = ( -# (torch.ones((bsz, self.max_output_length)) * self.pad_index) -# .type_as(image) -# .long() -# ) -# output_tokens[:, 0] = self.start_index -# for i in range(1, self.max_output_length): -# trg = output_tokens[:, :i] -# output = self.decode(image_features, trg) -# output = torch.argmax(output, dim=-1) -# output_tokens[:, i] = output[-1:] -# -# # Set all tokens after end token to be padding. -# for i in range(1, self.max_output_length): -# indices = output_tokens[:, i - 1] == self.end_index | ( -# output_tokens[:, i - 1] == self.pad_index -# ) -# output_tokens[indices, i] = self.pad_index -# -# return output_tokens diff --git a/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py deleted file mode 100644 index 8c19a01..0000000 --- a/text_recognizer/networks/transducer/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""Transducer modules.""" -from .tds_conv import TDS2d -from .transducer import load_transducer_loss, Transducer diff --git a/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py deleted file mode 100644 index 5fb8ba9..0000000 --- a/text_recognizer/networks/transducer/tds_conv.py +++ /dev/null @@ -1,208 +0,0 @@ -"""Time-Depth Separable Convolutions. - -References: - https://arxiv.org/abs/1904.02619 - https://arxiv.org/pdf/2010.01003.pdf - -Code stolen from: - https://github.com/facebookresearch/gtn_applications - - -""" -from typing import List, Tuple - -from einops import rearrange -import gtn -import numpy as np -import torch -from torch import nn -from torch import Tensor - - -class TDSBlock2d(nn.Module): - """Internal block of a 2D TDSC network.""" - - def __init__( - self, - in_channels: int, - img_depth: int, - kernel_size: Tuple[int], - dropout_rate: float, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.img_depth = img_depth - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - self.fc_dim = in_channels * img_depth - - # Network placeholders. - self.conv = None - self.mlp = None - self.instance_norm = None - - self._build_block() - - def _build_block(self) -> None: - # Convolutional block. - self.conv = nn.Sequential( - nn.Conv3d( - in_channels=self.in_channels, - out_channels=self.in_channels, - kernel_size=(1, self.kernel_size[0], self.kernel_size[1]), - padding=(0, self.kernel_size[0] // 2, self.kernel_size[1] // 2), - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - ) - - # MLP block. - self.mlp = nn.Sequential( - nn.Linear(self.fc_dim, self.fc_dim), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.Linear(self.fc_dim, self.fc_dim), - nn.Dropout(self.dropout_rate), - ) - - # Instance norm. - self.instance_norm = nn.ModuleList( - [ - nn.InstanceNorm2d(self.fc_dim, affine=True), - nn.InstanceNorm2d(self.fc_dim, affine=True), - ] - ) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, CD, H, W)` - - Returns: - Tensor: Output tensor. - - """ - B, CD, H, W = x.shape - C, D = self.in_channels, self.img_depth - residual = x - x = rearrange(x, "b (c d) h w -> b c d h w", c=C, d=D) - x = self.conv(x) - x = rearrange(x, "b c d h w -> b (c d) h w") - x += residual - - x = self.instance_norm[0](x) - - x = self.mlp(x.transpose(1, 3)).transpose(1, 3) + x - x + self.instance_norm[1](x) - - # Output shape: [B, CD, H, W] - return x - - -class TDS2d(nn.Module): - """TDS Netowrk. - - Structure is the following: - Downsample layer -> TDS2d group -> ... -> Linear output layer - - - """ - - def __init__( - self, - input_dim: int, - output_dim: int, - depth: int, - tds_groups: Tuple[int], - kernel_size: Tuple[int], - dropout_rate: float, - in_channels: int = 1, - ) -> None: - super().__init__() - - self.in_channels = in_channels - self.input_dim = input_dim - self.output_dim = output_dim - self.depth = depth - self.tds_groups = tds_groups - self.kernel_size = kernel_size - self.dropout_rate = dropout_rate - - self.tds = None - self.fc = None - - self._build_network() - - def _build_network(self) -> None: - in_channels = self.in_channels - modules = [] - stride_h = np.prod([grp["stride"][0] for grp in self.tds_groups]) - if self.input_dim % stride_h: - raise RuntimeError( - f"Image height not divisible by total stride {stride_h}." - ) - - for tds_group in self.tds_groups: - # Add downsample layer. - out_channels = self.depth * tds_group["channels"] - modules.extend( - [ - nn.Conv2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=self.kernel_size, - padding=(self.kernel_size[0] // 2, self.kernel_size[1] // 2), - stride=tds_group["stride"], - ), - nn.ReLU(inplace=True), - nn.Dropout(self.dropout_rate), - nn.InstanceNorm2d(out_channels, affine=True), - ] - ) - - for _ in range(tds_group["num_blocks"]): - modules.append( - TDSBlock2d( - tds_group["channels"], - self.depth, - self.kernel_size, - self.dropout_rate, - ) - ) - - in_channels = out_channels - - self.tds = nn.Sequential(*modules) - self.fc = nn.Linear(in_channels * self.input_dim // stride_h, self.output_dim) - - def forward(self, x: Tensor) -> Tensor: - """Forward pass. - - Args: - x (Tensor): Input tensor. - - Shape: - - x: :math: `(B, H, W)` - - Returns: - Tensor: Output tensor. - - """ - if len(x.shape) == 4: - x = x.squeeze(1) # Squeeze the channel dim away. - - B, H, W = x.shape - x = rearrange( - x, "b (h1 h2) w -> b h1 h2 w", h1=self.in_channels, h2=H // self.in_channels - ) - x = self.tds(x) - - # x shape: [B, C, H, W] - x = rearrange(x, "b c h w -> b w (c h)") - - return self.fc(x) diff --git a/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py deleted file mode 100644 index cadcecc..0000000 --- a/text_recognizer/networks/transducer/test.py +++ /dev/null @@ -1,60 +0,0 @@ -import torch -from torch import nn - -from text_recognizer.networks.transducer import load_transducer_loss, Transducer -import unittest - - -class TestTransducer(unittest.TestCase): - def test_viterbi(self): - T = 5 - N = 4 - B = 2 - - # fmt: off - emissions1 = torch.tensor(( - 0, 4, 0, 1, - 0, 2, 1, 1, - 0, 0, 0, 2, - 0, 0, 0, 2, - 8, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - emissions2 = torch.tensor(( - 0, 2, 1, 7, - 0, 2, 9, 1, - 0, 0, 0, 2, - 0, 0, 5, 2, - 1, 0, 0, 2, - ), - dtype=torch.float, - ).view(T, N) - # fmt: on - - # Test without blank: - labels = [[1, 3, 0], [3, 2, 3, 2, 3]] - transducer = Transducer( - tokens=["a", "b", "c", "d"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2, "d": 3}, - blank="none", - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - # Test with blank without repeats: - labels = [[1, 0], [2, 2]] - transducer = Transducer( - tokens=["a", "b", "c"], - graphemes_to_idx={"a": 0, "b": 1, "c": 2}, - blank="optional", - allow_repeats=False, - ) - emissions = torch.stack([emissions1, emissions2], dim=0) - predictions = transducer.viterbi(emissions) - self.assertEqual([p.tolist() for p in predictions], labels) - - -if __name__ == "__main__": - unittest.main() diff --git a/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py deleted file mode 100644 index d7e3d08..0000000 --- a/text_recognizer/networks/transducer/transducer.py +++ /dev/null @@ -1,410 +0,0 @@ -"""Transducer and the transducer loss function.py - -Stolen from: - https://github.com/facebookresearch/gtn_applications/blob/master/transducer.py - -""" -from pathlib import Path -import itertools -from typing import Dict, List, Optional, Union, Tuple - -from loguru import logger -import gtn -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.datasets.iam_preprocessor import Preprocessor - - -def make_scalar_graph(weight) -> gtn.Graph: - scalar = gtn.Graph() - scalar.add_node(True) - scalar.add_node(False, True) - scalar.add_arc(0, 1, 0, 0, weight) - return scalar - - -def make_chain_graph(sequence) -> gtn.Graph: - graph = gtn.Graph(False) - graph.add_node(True) - for i, s in enumerate(sequence): - graph.add_node(False, i == (len(sequence) - 1)) - graph.add_arc(i, i + 1, s) - return graph - - -def make_transitions_graph( - ngram: int, num_tokens: int, calc_grad: bool = False -) -> gtn.Graph: - transitions = gtn.Graph(calc_grad) - transitions.add_node(True, ngram == 1) - - state_map = {(): 0} - - # First build transitions which include : - for n in range(1, ngram): - for state in itertools.product(range(num_tokens), repeat=n): - in_idx = state_map[state[:-1]] - out_idx = transitions.add_node(False, ngram == 1) - state_map[state] = out_idx - transitions.add_arc(in_idx, out_idx, state[-1]) - - for state in itertools.product(range(num_tokens), repeat=ngram): - state_idx = state_map[state[:-1]] - new_state_idx = state_map[state[1:]] - # p(state[-1] | state[:-1]) - transitions.add_arc(state_idx, new_state_idx, state[-1]) - - if ngram > 1: - # Build transitions which include : - end_idx = transitions.add_node(False, True) - for in_idx in range(end_idx): - transitions.add_arc(in_idx, end_idx, gtn.epsilon) - - return transitions - - -def make_lexicon_graph(word_pieces: List, graphemes_to_idx: Dict) -> gtn.Graph: - """Constructs a graph which transduces letters to word pieces.""" - graph = gtn.Graph(False) - graph.add_node(True, True) - for i, wp in enumerate(word_pieces): - prev = 0 - for l in wp[:-1]: - n = graph.add_node() - graph.add_arc(prev, n, graphemes_to_idx[l], gtn.epsilon) - prev = n - graph.add_arc(prev, 0, graphemes_to_idx[wp[-1]], i) - graph.arc_sort() - return graph - - -def make_token_graph( - token_list: List, blank: str = "none", allow_repeats: bool = True -) -> gtn.Graph: - """Constructs a graph with all the individual token transition models.""" - if not allow_repeats and blank != "optional": - raise ValueError("Must use blank='optional' if disallowing repeats.") - - ntoks = len(token_list) - graph = gtn.Graph(False) - - # Creating nodes - graph.add_node(True, True) - for i in range(ntoks): - # We can consume one or more consecutive word - # pieces for each emission: - # E.g. [ab, ab, ab] transduces to [ab] - graph.add_node(False, blank != "forced") - - if blank != "none": - graph.add_node() - - # Creating arcs - if blank != "none": - # Blank index is assumed to be last (ntoks) - graph.add_arc(0, ntoks + 1, ntoks, gtn.epsilon) - graph.add_arc(ntoks + 1, 0, gtn.epsilon) - - for i in range(ntoks): - graph.add_arc((ntoks + 1) if blank == "forced" else 0, i + 1, i) - graph.add_arc(i + 1, i + 1, i, gtn.epsilon) - - if allow_repeats: - if blank == "forced": - # Allow transitions from token to blank only - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - else: - # Allow transition from token to blank and all other tokens - graph.add_arc(i + 1, 0, gtn.epsilon) - - else: - # allow transitions to blank and all other tokens except the same token - graph.add_arc(i + 1, ntoks + 1, ntoks, gtn.epsilon) - for j in range(ntoks): - if i != j: - graph.add_arc(i + 1, j + 1, j, j) - - return graph - - -class TransducerLossFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - inputs, - targets, - tokens, - lexicon, - transition_params=None, - transitions=None, - reduction="none", - ) -> Tensor: - B, T, C = inputs.shape - - losses = [None] * B - emissions_graphs = [None] * B - - if transitions is not None: - if transition_params is None: - raise ValueError("Specified transitions, but not transition params.") - - cpu_data = transition_params.cpu().contiguous() - transitions.set_weights(cpu_data.data_ptr()) - transitions.calc_grad = transition_params.requires_grad - transitions.zero_grad() - - def process(b: int) -> None: - # Create emission graph: - emissions = gtn.linear_graph(T, C, inputs.requires_grad) - cpu_data = inputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - target = make_chain_graph(targets[b]) - target.arc_sort(True) - - # Create token tot grapheme decomposition graph - tokens_target = gtn.remove(gtn.project_output(gtn.compose(target, lexicon))) - tokens_target.arc_sort() - - # Create alignment graph: - aligments = gtn.project_input( - gtn.remove(gtn.compose(tokens, tokens_target)) - ) - aligments.arc_sort() - - # Add transitions scores: - if transitions is not None: - aligments = gtn.intersect(transitions, aligments) - aligments.arc_sort() - - loss = gtn.forward_score(gtn.intersect(emissions, aligments)) - - # Normalize if needed: - if transitions is not None: - norm = gtn.forward_score(gtn.intersect(emissions, transitions)) - loss = gtn.subtract(loss, norm) - - losses[b] = gtn.negate(loss) - - # Save for backward: - if emissions.calc_grad: - emissions_graphs[b] = emissions - - gtn.parallel_for(process, range(B)) - - ctx.graphs = (losses, emissions_graphs, transitions) - ctx.input_shape = inputs.shape - - # Optionally reduce by target length - if reduction == "mean": - scales = [(1 / len(t) if len(t) > 0 else 1.0) for t in targets] - else: - scales = [1.0] * B - - ctx.scales = scales - - loss = torch.tensor([l.item() * s for l, s in zip(losses, scales)]) - return torch.mean(loss.to(inputs.device)) - - @staticmethod - def backward(ctx, grad_output) -> Tuple: - losses, emissions_graphs, transitions = ctx.graphs - scales = ctx.scales - - B, T, C = ctx.input_shape - calc_emissions = ctx.needs_input_grad[0] - input_grad = torch.empty((B, T, C)) if calc_emissions else None - - def process(b: int) -> None: - scale = make_scalar_graph(scales[b]) - gtn.backward(losses[b], scale) - emissions = emissions_graphs[b] - if calc_emissions: - grad = emissions.grad().weights_to_numpy() - input_grad[b] = torch.tensor(grad).view(1, T, C) - - gtn.parallel_for(process, range(B)) - - if calc_emissions: - input_grad = input_grad.to(grad_output.device) - input_grad *= grad_output / B - - if ctx.needs_input_grad[4]: - grad = transitions.grad().weights_to_numpy() - transition_grad = torch.tensor(grad).to(grad_output.device) - transition_grad *= grad_output / B - else: - transition_grad = None - - return ( - input_grad, - None, # target - None, # tokens - None, # lexicon - transition_grad, # transition params - None, # transitions graph - None, - ) - - -TransducerLoss = TransducerLossFunction.apply - - -class Transducer(nn.Module): - def __init__( - self, - tokens: List, - graphemes_to_idx: Dict, - ngram: int = 0, - transitions: str = None, - blank: str = "none", - allow_repeats: bool = True, - reduction: str = "none", - ) -> None: - """A generic transducer loss function. - - Args: - tokens (List) : A list of iterable objects (e.g. strings, tuples, etc) - representing the output tokens of the model (e.g. letters, - word-pieces, words). For example ["a", "b", "ab", "ba", "aba"] - could be a list of sub-word tokens. - graphemes_to_idx (dict) : A dictionary mapping grapheme units (e.g. - "a", "b", ..) to their corresponding integer index. - ngram (int) : Order of the token-level transition model. If `ngram=0` - then no transition model is used. - blank (string) : Specifies the usage of blank token - 'none' - do not use blank token - 'optional' - allow an optional blank inbetween tokens - 'forced' - force a blank inbetween tokens (also referred to as garbage token) - allow_repeats (boolean) : If false, then we don't allow paths with - consecutive tokens in the alignment graph. This keeps the graph - unambiguous in the sense that the same input cannot transduce to - different outputs. - """ - super().__init__() - if blank not in ["optional", "forced", "none"]: - raise ValueError( - "Invalid value specified for blank. Must be in ['optional', 'forced', 'none']" - ) - self.tokens = make_token_graph(tokens, blank=blank, allow_repeats=allow_repeats) - self.lexicon = make_lexicon_graph(tokens, graphemes_to_idx) - self.ngram = ngram - if ngram > 0 and transitions is not None: - raise ValueError("Only one of ngram and transitions may be specified") - - if ngram > 0: - transitions = make_transitions_graph( - ngram, len(tokens) + int(blank != "none"), True - ) - - if transitions is not None: - self.transitions = transitions - self.transitions.arc_sort() - self.transitions_params = nn.Parameter( - torch.zeros(self.transitions.num_arcs()) - ) - else: - self.transitions = None - self.transitions_params = None - self.reduction = reduction - - def forward(self, inputs: Tensor, targets: Tensor) -> TransducerLoss: - TransducerLoss( - inputs, - targets, - self.tokens, - self.lexicon, - self.transitions_params, - self.transitions, - self.reduction, - ) - - def viterbi(self, outputs: Tensor) -> List[Tensor]: - B, T, C = outputs.shape - - if self.transitions is not None: - cpu_data = self.transition_params.cpu().contiguous() - self.transitions.set_weights(cpu_data.data_ptr()) - self.transitions.calc_grad = False - - self.tokens.arc_sort() - - paths = [None] * B - - def process(b: int) -> None: - emissions = gtn.linear_graph(T, C, False) - cpu_data = outputs[b].cpu().contiguous() - emissions.set_weights(cpu_data.data_ptr()) - - if self.transitions is not None: - full_graph = gtn.intersect(emissions, self.transitions) - else: - full_graph = emissions - - # Find the best path and remove back-off arcs: - path = gtn.remove(gtn.viterbi_path(full_graph)) - - # Left compose the viterbi path with the "aligment to token" - # transducer to get the outputs: - path = gtn.compose(path, self.tokens) - - # When there are ambiguous paths (allow_repeats is true), we take - # the shortest: - path = gtn.viterbi_path(path) - path = gtn.remove(gtn.project_output(path)) - paths[b] = path.labels_to_list() - - gtn.parallel_for(process, range(B)) - predictions = [torch.IntTensor(path) for path in paths] - return predictions - - -def load_transducer_loss( - num_features: int, - ngram: int, - tokens: str, - lexicon: str, - transitions: str, - blank: str, - allow_repeats: bool, - prepend_wordsep: bool = False, - use_words: bool = False, - data_dir: Optional[Union[str, Path]] = None, - reduction: str = "mean", -) -> Tuple[Transducer, int]: - if data_dir is None: - data_dir = ( - Path(__file__).resolve().parents[4] / "data" / "raw" / "iam" / "iamdb" - ) - logger.debug(f"Using data dir: {data_dir}") - if not data_dir.exists(): - raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") - else: - data_dir = Path(data_dir) - processed_path = ( - Path(__file__).resolve().parents[4] / "data" / "processed" / "iam_lines" - ) - tokens_path = processed_path / tokens - lexicon_path = processed_path / lexicon - - if transitions is not None: - transitions = gtn.load(str(processed_path / transitions)) - - preprocessor = Preprocessor( - data_dir, num_features, tokens_path, lexicon_path, use_words, prepend_wordsep, - ) - - num_tokens = preprocessor.num_tokens - - criterion = Transducer( - preprocessor.tokens, - preprocessor.graphemes_to_index, - ngram=ngram, - transitions=transitions, - blank=blank, - allow_repeats=allow_repeats, - reduction=reduction, - ) - - return criterion, num_tokens + int(blank != "none") diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index a3f3011..d9e63ef 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1 +1,3 @@ """Transformer modules.""" +from .nystromer.nystromer import Nystromer +from .vit import ViT diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index b2c703f..a44a525 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -1,8 +1,6 @@ """Generates the attention layer architecture.""" from functools import partial -from typing import Any, Dict, Optional, Type - -from click.types import Tuple +from typing import Any, Dict, Optional, Tuple, Type from torch import nn, Tensor @@ -30,6 +28,7 @@ class AttentionLayers(nn.Module): pre_norm: bool = True, ) -> None: super().__init__() + self.dim = dim attn_fn = partial(attn_fn, dim=dim, num_heads=num_heads, **attn_kwargs) norm_fn = partial(norm_fn, dim) ff_fn = partial(ff_fn, dim=dim, **ff_kwargs) diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py index 9466f6e..7140537 100644 --- a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py +++ b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py @@ -1,4 +1,5 @@ """Absolute positional embedding.""" +import torch from torch import nn, Tensor diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index 60ab1ce..31088b4 100644 --- a/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py @@ -19,7 +19,9 @@ class Transformer(nn.Module): emb_dropout: float = 0.0, use_pos_emb: bool = True, ) -> None: + super().__init__() dim = attn_layers.dim + self.attn_layers = attn_layers emb_dim = emb_dim if emb_dim is not None else dim self.max_seq_len = max_seq_len @@ -32,7 +34,6 @@ class Transformer(nn.Module): ) self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity() - self.attn_layers = attn_layers self.norm = nn.LayerNorm(dim) self._init_weights() @@ -45,12 +46,12 @@ class Transformer(nn.Module): def forward( self, x: Tensor, - mask: Optional[Tensor], + mask: Optional[Tensor] = None, return_embeddings: bool = False, **kwargs: Any ) -> Tensor: b, n, device = *x.shape, x.device - x += self.token_emb(x) + x = self.token_emb(x) if self.pos_emb is not None: x += self.pos_emb(x) x = self.emb_dropout(x) diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 9c6b151..05b10a8 100644 --- a/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py @@ -22,42 +22,3 @@ def activation_function(activation: str) -> Type[nn.Module]: ] ) return activation_fns[activation.lower()] - - -# def configure_backbone(backbone: Union[OmegaConf, NamedTuple]) -> Type[nn.Module]: -# """Loads a backbone network.""" -# network_module = importlib.import_module("text_recognizer.networks") -# backbone_class = getattr(network_module, backbone.type) -# -# if "pretrained" in backbone.args: -# logger.info("Loading pretrained backbone.") -# checkpoint_file = Path(__file__).resolve().parents[2] / backbone.args.pop( -# "pretrained" -# ) -# -# # Loading state directory. -# state_dict = torch.load(checkpoint_file) -# network_args = state_dict["network_args"] -# weights = state_dict["model_state"] -# -# freeze = False -# if "freeze" in backbone.args and backbone.args["freeze"] is True: -# backbone.args.pop("freeze") -# freeze = True -# -# # Initializes the network with trained weights. -# backbone_ = backbone_(**backbone.args) -# backbone_.load_state_dict(weights) -# if freeze: -# for params in backbone_.parameters(): -# params.requires_grad = False -# else: -# backbone_ = getattr(network_module, backbone.type) -# backbone_ = backbone_(**backbone.args) -# -# if "remove_layers" in backbone_args and backbone_args["remove_layers"] is not None: -# backbone = nn.Sequential( -# *list(backbone.children())[:][: -backbone_args["remove_layers"]] -# ) -# -# return backbone -- cgit v1.2.3-70-g09d2