diff options
Diffstat (limited to 'text_recognizer/networks/cnn_transformer.py')
-rw-r--r-- | text_recognizer/networks/cnn_transformer.py | 257 |
1 files changed, 132 insertions, 125 deletions
diff --git a/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index 9150b55..e23a15d 100644 --- a/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -1,158 +1,165 @@ -"""A CNN-Transformer for image to text recognition.""" -from typing import Dict, Optional, Tuple +"""A Transformer with a cnn backbone. + +The network encodes a image with a convolutional backbone to a latent representation, +i.e. feature maps. A 2d positional encoding is applied to the feature maps for +spatial information. The resulting feature are then set to a transformer decoder +together with the target tokens. + +TODO: Local attention for lower layer in attention. + +""" +import importlib +import math +from typing import Dict, Optional, Union, Sequence, Type from einops import rearrange +from omegaconf import DictConfig, OmegaConf import torch from torch import nn from torch import Tensor -from text_recognizer.networks.transformer import PositionalEncoding, Transformer -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.util import configure_backbone +from text_recognizer.data.emnist import NUM_SPECIAL_TOKENS +from text_recognizer.networks.transformer import ( + Decoder, + DecoderLayer, + PositionalEncoding, + PositionalEncoding2D, + target_padding_mask, +) +NUM_WORD_PIECES = 1000 -class CNNTransformer(nn.Module): - """CNN+Transfomer for image to sequence prediction.""" +class CNNTransformer(nn.Module): def __init__( self, - num_encoder_layers: int, - num_decoder_layers: int, - hidden_dim: int, - vocab_size: int, - num_heads: int, - adaptive_pool_dim: Tuple, - expansion_dim: int, - dropout_rate: float, - trg_pad_index: int, - max_len: int, - backbone: str, - backbone_args: Optional[Dict] = None, - activation: str = "gelu", - pool_kernel: Optional[Tuple[int, int]] = None, + input_shape: Sequence[int], + output_shape: Sequence[int], + encoder: Union[DictConfig, Dict], + vocab_size: Optional[int] = None, + num_decoder_layers: int = 4, + hidden_dim: int = 256, + num_heads: int = 4, + expansion_dim: int = 1024, + dropout_rate: float = 0.1, + transformer_activation: str = "glu", ) -> None: - super().__init__() - self.trg_pad_index = trg_pad_index - self.vocab_size = vocab_size - self.backbone = configure_backbone(backbone, backbone_args) - - if pool_kernel is not None: - self.max_pool = nn.MaxPool2d(pool_kernel, stride=2) - else: - self.max_pool = None - - self.character_embedding = nn.Embedding(self.vocab_size, hidden_dim) - - self.src_position_embedding = nn.Parameter(torch.randn(1, max_len, hidden_dim)) - self.pos_dropout = nn.Dropout(p=dropout_rate) - self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - - nn.init.normal_(self.character_embedding.weight, std=0.02) - - self.adaptive_pool = ( - nn.AdaptiveAvgPool2d((adaptive_pool_dim)) if adaptive_pool_dim else None + self.vocab_size = ( + NUM_WORD_PIECES + NUM_SPECIAL_TOKENS if vocab_size is None else vocab_size ) + self.hidden_dim = hidden_dim + self.max_output_length = output_shape[0] - self.transformer = Transformer( - num_encoder_layers, - num_decoder_layers, - hidden_dim, - num_heads, - expansion_dim, - dropout_rate, - activation, + # Image backbone + self.encoder = self._configure_encoder(encoder) + self.feature_map_encoding = PositionalEncoding2D( + hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2] ) - self.head = nn.Sequential( - # nn.Linear(hidden_dim, hidden_dim * 2), - # activation_function(activation), - nn.Linear(hidden_dim, vocab_size), - ) + # Target token embedding + self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) + self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) - def _create_trg_mask(self, trg: Tensor) -> Tensor: - # Move this outside the transformer. - trg_pad_mask = (trg != self.trg_pad_index)[:, None, None] - trg_len = trg.shape[1] - trg_sub_mask = torch.tril( - torch.ones((trg_len, trg_len), device=trg.device) - ).bool() - trg_mask = trg_pad_mask & trg_sub_mask - return trg_mask - - def encoder(self, src: Tensor) -> Tensor: - """Forward pass with the encoder of the transformer.""" - return self.transformer.encoder(src) - - def decoder(self, trg: Tensor, memory: Tensor, trg_mask: Tensor) -> Tensor: - """Forward pass with the decoder of the transformer + classification head.""" - return self.head( - self.transformer.decoder(trg=trg, memory=memory, trg_mask=trg_mask) + # Transformer decoder + self.decoder = Decoder( + decoder_layer=DecoderLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion_dim=expansion_dim, + dropout_rate=dropout_rate, + activation=transformer_activation, + ), + num_layers=num_decoder_layers, + norm=nn.LayerNorm(hidden_dim), ) - def extract_image_features(self, src: Tensor) -> Tensor: - """Extracts image features with a backbone neural network. - - It seem like the winning idea was to swap channels and width dimension and collapse - the height dimension. The transformer is learning like a baby with this implementation!!! :D - Ohhhh, the joy I am experiencing right now!! Bring in the beers! :D :D :D + # Classification head + self.head = nn.Linear(hidden_dim, self.vocab_size) - Args: - src (Tensor): Input tensor. + # Initialize weights + self._init_weights() - Returns: - Tensor: A input src to the transformer. + def _init_weights(self) -> None: + """Initialize network weights.""" + self.trg_embedding.weight.data.uniform_(-0.1, 0.1) + self.head.bias.data.zero_() + self.head.weight.data.uniform_(-0.1, 0.1) - """ - # If batch dimension is missing, it needs to be added. - if len(src.shape) < 4: - src = src[(None,) * (4 - len(src.shape))] - - src = self.backbone(src) - - if self.max_pool is not None: - src = self.max_pool(src) - - if self.adaptive_pool is not None and len(src.shape) == 4: - src = rearrange(src, "b c h w -> b w c h") - src = self.adaptive_pool(src) - src = src.squeeze(3) - elif len(src.shape) == 4: - src = rearrange(src, "b c h w -> b (h w) c") + nn.init.kaiming_normal_( + self.feature_map_encoding.weight.data, + a=0, + mode="fan_out", + nonlinearity="relu", + ) + if self.feature_map_encoding.bias is not None: + _, fan_out = nn.init._calculate_fan_in_and_fan_out( + self.feature_map_encoding.weight.data + ) + bound = 1 / math.sqrt(fan_out) + nn.init.normal_(self.feature_map_encoding.bias, -bound, bound) + + @staticmethod + def _configure_encoder(encoder: Union[DictConfig, Dict]) -> Type[nn.Module]: + encoder = OmegaConf.create(encoder) + network_module = importlib.import_module("text_recognizer.networks") + encoder_class = getattr(network_module, encoder.type) + return encoder_class(**encoder.args) + + def encode(self, image: Tensor) -> Tensor: + """Extracts image features with backbone. - b, t, _ = src.shape + Args: + image (Tensor): Image(s) of handwritten text. - src += self.src_position_embedding[:, :t] - src = self.pos_dropout(src) + Retuns: + Tensor: Image features. - return src + Shapes: + - image: :math: `(B, C, H, W)` + - latent: :math: `(B, T, C)` - def target_embedding(self, trg: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes target tensor with embedding and postion. + """ + # Extract image features. + image_features = self.encoder(image) - Args: - trg (Tensor): Target tensor. + # Add 2d encoding to the feature maps. + image_features = self.feature_map_encoding(image_features) - Returns: - Tuple[Tensor, Tensor]: Encoded target tensor and target mask. + # Collapse features maps height and width. + image_features = rearrange(image_features, "b c h w -> b (h w) c") + return image_features - """ - trg = self.character_embedding(trg.long()) + def decode(self, memory: Tensor, trg: Tensor) -> Tensor: + """Decodes image features with transformer decoder.""" + trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) + trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) trg = self.trg_position_encoding(trg) - return trg - - def decode_image_features( - self, image_features: Tensor, trg: Optional[Tensor] = None - ) -> Tensor: - """Takes images features from the backbone and decodes them with the transformer.""" - trg_mask = self._create_trg_mask(trg) - trg = self.target_embedding(trg) - out = self.transformer(image_features, trg, trg_mask=trg_mask) - + out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) logits = self.head(out) return logits - def forward(self, x: Tensor, trg: Optional[Tensor] = None) -> Tensor: - """Forward pass with CNN transfomer.""" - image_features = self.extract_image_features(x) - logits = self.decode_image_features(image_features, trg) - return logits + def predict(self, image: Tensor) -> Tensor: + """Transcribes text in image(s).""" + bsz = image.shape[0] + image_features = self.encode(image) + + output_tokens = ( + (torch.ones((bsz, self.max_output_length)) * self.pad_index) + .type_as(image) + .long() + ) + output_tokens[:, 0] = self.start_index + for i in range(1, self.max_output_length): + trg = output_tokens[:, :i] + output = self.decode(image_features, trg) + output = torch.argmax(output, dim=-1) + output_tokens[:, i] = output[-1:] + + # Set all tokens after end token to be padding. + for i in range(1, self.max_output_length): + indices = output_tokens[:, i - 1] == self.end_index | ( + output_tokens[:, i - 1] == self.pad_index + ) + output_tokens[indices, i] = self.pad_index + + return output_tokens |