Diffstat (limited to 'text_recognizer/networks/cnn_transformer.py')
-rw-r--r--  text_recognizer/networks/cnn_transformer.py  257
1 file changed, 132 insertions(+), 125 deletions(-)
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index 9150b55..e23a15d 100644
--- a/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -1,158 +1,165 @@
-"""A CNN-Transformer for image to text recognition."""
-from typing import Dict, Optional, Tuple
+"""A Transformer with a CNN backbone.
+
+The network encodes an image with a convolutional backbone into a latent
+representation, i.e. feature maps. A 2D positional encoding is applied to the
+feature maps to retain spatial information. The resulting features are then
+fed to a transformer decoder together with the target tokens.
+
+TODO: Local attention for the lower decoder layers.
+
+"""
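+
+# Example usage (an illustrative sketch; the encoder class name and the
+# shape values below are assumptions, not taken from this module):
+#
+#   network = CNNTransformer(
+#       input_shape=(1, 576, 640),   # (C, H, W)
+#       output_shape=(682, 1),       # max output sequence length first
+#       encoder={"type": "ResidualNetwork", "args": {}},
+#   )
+#   tokens = network.predict(images)  # (B, max_output_length) token indices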
+import importlib
+import math
+from typing import Dict, Optional, Sequence, Union
from einops import rearrange
+from omegaconf import DictConfig, OmegaConf
import torch
from torch import nn
from torch import Tensor
-from text_recognizer.networks.transformer import PositionalEncoding, Transformer
-from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.util import configure_backbone
+from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS
+from text_recognizer.networks.transformer import (
+ Decoder,
+ DecoderLayer,
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
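+# Number of word pieces in the output vocabulary (excluding special tokens).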
+NUM_WORD_PIECES = 1000
-class CNNTransformer(nn.Module):
- """CNN+Transfomer for image to sequence prediction."""
+class CNNTransformer(nn.Module):
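+    """A CNN backbone feeding a Transformer decoder for image-to-sequence prediction."""
+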
def __init__(
self,
- num_encoder_layers: int,
- num_decoder_layers: int,
- hidden_dim: int,
- vocab_size: int,
- num_heads: int,
- adaptive_pool_dim: Tuple,
- expansion_dim: int,
- dropout_rate: float,
- trg_pad_index: int,
- max_len: int,
- backbone: str,
- backbone_args: Optional[Dict] = None,
- activation: str = "gelu",
- pool_kernel: Optional[Tuple[int, int]] = None,
+ input_shape: Sequence[int],
+ output_shape: Sequence[int],
+ encoder: Union[DictConfig, Dict],
+ vocab_size: Optional[int] = None,
+ num_decoder_layers: int = 4,
+ hidden_dim: int = 256,
+ num_heads: int = 4,
+ expansion_dim: int = 1024,
+ dropout_rate: float = 0.1,
+        transformer_activation: str = "glu",
+        start_index: int = 0,
+        end_index: int = 1,
+        pad_index: int = 2,
) -> None:
- super().__init__()
- self.trg_pad_index = trg_pad_index
- self.vocab_size = vocab_size
- self.backbone = configure_backbone(backbone, backbone_args)
-
- if pool_kernel is not None:
- self.max_pool = nn.MaxPool2d(pool_kernel, stride=2)
- else:
- self.max_pool = None
-
- self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim)
-
- self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim))
- self.pos_dropout = nn.Dropout(p=dropout_rate)
- self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
-
- nn.init.normal_(self.character_embedding.weight, std=0.02)
-
- self.adaptive_pool = (
- nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None
+        super().__init__()
+        self.vocab_size = (
+            NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size
+        )
+        self.hidden_dim = hidden_dim
+        self.max_output_length = output_shape[0]
+
+        # Special token indices used by decode/predict; the signature defaults
+        # are placeholders and must match the dataset's token mapping.
+        self.start_index = start_index
+        self.end_index = end_index
+        self.pad_index = pad_index
- self.transformer = Transformer(
- num_encoder_layers,
- num_decoder_layers,
- hidden_dim,
- num_heads,
- expansion_dim,
- dropout_rate,
- activation,
+ # Image backbone
+ self.encoder = self._configure_encoder(encoder)
+ self.feature_map_encoding = PositionalEncoding2D(
+ hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
)
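+        # NOTE: max_h/max_w use the full input height/width, an upper bound on
+        # the size of the backbone's output feature maps.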
- self.head = nn.Sequential(
- # nn.Linear(hidden_dim, hidden_dim * 2),
- # activation_function(activation),
- nn.Linear(hidden_dim, vocab_size),
- )
+ # Target token embedding
+ self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
- def _create_trg_mask(self, trg: Tensor) -> Tensor:
- # Move this outside the transformer.
- trg_pad_mask = (trg != self.trg_pad_index)[:, None, None]
- trg_len = trg.shape[1]
- trg_sub_mask = torch.tril(
- torch.ones((trg_len, trg_len), device=trg.device)
- ).bool()
- trg_mask = trg_pad_mask & trg_sub_mask
- return trg_mask
-
- def encoder(self, src: Tensor) -> Tensor:
- """Forward pass with the encoder of the transformer."""
- return self.transformer.encoder(src)
-
- def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor:
- """Forward pass with the decoder of the transformer + classification head."""
- return self.head(
- self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask)
+ # Transformer decoder
+ self.decoder = Decoder(
+ decoder_layer=DecoderLayer(
+ hidden_dim=hidden_dim,
+ num_heads=num_heads,
+ expansion_dim=expansion_dim,
+ dropout_rate=dropout_rate,
+ activation=transformer_activation,
+ ),
+ num_layers=num_decoder_layers,
+ norm=nn.LayerNorm(hidden_dim),
)
- def extract_image_features(self, src: Tensor) -> Tensor:
- """Extracts image features with a backbone neural network.
-
- It seem like the winning idea was to swap channels and width dimension and collapse
- the height dimension. The transformer is learning like a baby with this implementation!!! :D
- Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D
+ # Classification head
+ self.head = nn.Linear(hidden_dim, self.vocab_size)
- Args:
- src (Tensor): Input tensor.
+ # Initialize weights
+ self._init_weights()
- Returns:
- Tensor: A input src to the transformer.
+ def _init_weights(self) -> None:
+ """Initialize network weights."""
+ self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+ self.head.bias.data.zero_()
+ self.head.weight.data.uniform_(-0.1, 0.1)
- """
- # If batch dimension is missing, it needs to be added.
- if len(src.shape) < 4:
- src = src[(None,) * (4 - len(src.shape))]
-
- src = self.backbone(src)
-
- if self.max_pool is not None:
- src = self.max_pool(src)
-
- if self.adaptive_pool is not None and len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b w c h")
- src = self.adaptive_pool(src)
- src = src.squeeze(3)
- elif len(src.shape) == 4:
- src = rearrange(src, "b c h w -> b (h w) c")
+        # The 2D positional encoding is typically fixed (sinusoidal); only
+        # initialize it if it actually has learnable parameters.
+        if getattr(self.feature_map_encoding, "weight", None) is not None:
+            nn.init.kaiming_normal_(
+                self.feature_map_encoding.weight.data,
+                a=0,
+                mode="fan_out",
+                nonlinearity="relu",
+            )
+            if getattr(self.feature_map_encoding, "bias", None) is not None:
+                _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+                    self.feature_map_encoding.weight.data
+                )
+                bound = 1 / math.sqrt(fan_out)
+                nn.init.normal_(self.feature_map_encoding.bias, -bound, bound)
+
+ @staticmethod
+    def _configure_encoder(encoder: Union[DictConfig, Dict]) -> nn.Module:
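+        """Instantiates the backbone network from a config.
+
+        The config is expected to hold a `type` key naming a class exported
+        from `text_recognizer.networks` and an `args` dict of constructor
+        kwargs, e.g. (hypothetical values) {"type": "ResidualNetwork",
+        "args": {}}.
+        """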
+ encoder = OmegaConf.create(encoder)
+ network_module = importlib.import_module("text_recognizer.networks")
+ encoder_class = getattr(network_module, encoder.type)
+ return encoder_class(**encoder.args)
+
+ def encode(self, image: Tensor) -> Tensor:
+        """Extracts image features with the backbone.
- b, t, _ = src.shape
+ Args:
+ image (Tensor): Image(s) of handwritten text.
- src += self.src_position_embedding[:, :t]
- src = self.pos_dropout(src)
+        Returns:
+            Tensor: Image features.
- return src
+        Shapes:
+            - image: :math:`(B, C, H, W)`
+            - latent: :math:`(B, T, C)`
- def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]:
- """Encodes target tensor with embedding and postion.
+ """
+ # Extract image features.
+ image_features = self.encoder(image)
- Args:
- trg (Tensor): Target tensor.
+        # Add the 2D positional encoding to the feature maps.
+ image_features = self.feature_map_encoding(image_features)
- Returns:
- Tuple[Tensor, Tensor]: Encoded target tensor and target mask.
+        # Collapse the feature map height and width into one sequence dimension.
+ image_features = rearrange(image_features, "b c h w -> b (h w) c")
+ return image_features
- """
- trg = self.character_embedding(trg.long())
+ def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+        """Decodes image features with the transformer decoder.
+
+        Args:
+            memory (Tensor): Encoded image features, shape (B, T, C).
+            trg (Tensor): Target token indices, shape (B, S).
+
+        Returns:
+            Tensor: Logits over the vocabulary, shape (B, S, vocab_size).
+
+        """
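+        # Build the target mask from the padding index.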
+ trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
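+        # Embed the target tokens and scale by sqrt(hidden_dim), as in
+        # "Attention Is All You Need".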
+ trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
trg = self.trg_position_encoding(trg)
- return trg
-
- def decode_image_features(
- self, image_features: Tensor, trg: Optional[Tensor] = None
- ) -> Tensor:
- """Takes images features from the backbone and decodes them with the transformer."""
- trg_mask = self._create_trg_mask(trg)
- trg = self.target_embedding(trg)
- out = self.transformer(image_features, trg, trg_mask=trg_mask)
-
+ out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
logits = self.head(out)
return logits
- def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor:
- """Forward pass with CNN transfomer."""
- image_features = self.extract_image_features(x)
- logits = self.decode_image_features(image_features, trg)
- return logits
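+    def forward(self, image: Tensor, trg: Tensor) -> Tensor:
+        """Computes logits for training (a minimal sketch: assumes teacher
+        forcing with the ground-truth target tokens)."""
+        image_features = self.encode(image)
+        logits = self.decode(image_features, trg)
+        return logits
+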
+ def predict(self, image: Tensor) -> Tensor:
+        """Transcribes text in image(s) with greedy autoregressive decoding.
+
+        Args:
+            image (Tensor): Image(s) of handwritten text, shape (B, C, H, W).
+
+        Returns:
+            Tensor: Predicted token indices, shape (B, max_output_length).
+
+        """
+ bsz = image.shape[0]
+ image_features = self.encode(image)
+
+ output_tokens = (
+ (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+ .type_as(image)
+ .long()
+ )
+ output_tokens[:, 0] = self.start_index
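+
+        # Greedy decoding: repeatedly feed the tokens generated so far and
+        # take the argmax prediction for the next position.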
+        for i in range(1, self.max_output_length):
+            trg = output_tokens[:, :i]
+            output = self.decode(image_features, trg)  # (B, i, vocab_size)
+            output = torch.argmax(output, dim=-1)  # (B, i)
+            output_tokens[:, i] = output[:, -1]  # Last predicted token per example.
+
+        # Set all tokens after the end token to padding.
+        for i in range(1, self.max_output_length):
+            indices = (output_tokens[:, i - 1] == self.end_index) | (
+                output_tokens[:, i - 1] == self.pad_index
+            )
+            output_tokens[indices, i] = self.pad_index
+
+ return output_tokens