author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:16 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-04-04 23:08:16 +0200 |
commit | 2d4714fcfeb8914f240a0d36d938b434e82f191b (patch) | |
tree | 32e7b3446332cee4685ec90870210c51f9f1279f /text_recognizer/networks | |
parent | 5dc8a7097ab6b4f39f0a3add408e3fd0f131f85b (diff) | |
Add new transformer network
Diffstat (limited to 'text_recognizer/networks')
4 files changed, 174 insertions, 4 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
new file mode 100644
index 0000000..5a093dc
--- /dev/null
+++ b/text_recognizer/networks/image_transformer.py
@@ -0,0 +1,159 @@
+"""A Transformer with a CNN backbone.
+
+The network encodes an image with a convolutional backbone into a latent
+representation, i.e. feature maps. A 2D positional encoding is applied to the
+feature maps for spatial information. The resulting features are then sent to
+a transformer decoder together with the target tokens.
+
+TODO: Local attention for the transformer.
+
+"""
+import math
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torchvision
+
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.networks.transformer import (
+    Decoder,
+    DecoderLayer,
+    PositionalEncoding,
+    PositionalEncoding2D,
+    target_padding_mask,
+)
+
+
+class ImageTransformer(nn.Module):
+    """Transformer decoder on top of features from a convolutional backbone."""
+
+    def __init__(
+        self,
+        input_shape: Sequence[int],
+        output_shape: Sequence[int],
+        backbone: Type[nn.Module],
+        mapping: Optional[List[str]] = None,
+        num_decoder_layers: int = 4,
+        hidden_dim: int = 256,
+        num_heads: int = 4,
+        expansion_dim: int = 4,
+        dropout_rate: float = 0.1,
+        transformer_activation: str = "glu",
+    ) -> None:
+        super().__init__()
+
+        # Configure mapping
+        mapping, inverse_mapping = self._configure_mapping(mapping)
+        self.vocab_size = len(mapping)
+        self.hidden_dim = hidden_dim
+        self.max_output_length = output_shape[0]
+        self.start_index = inverse_mapping["<s>"]
+        self.end_index = inverse_mapping["<e>"]
+        self.pad_index = inverse_mapping["<p>"]
+
+        # Image backbone
+        self.backbone = backbone
+        self.latent_encoding = PositionalEncoding2D(
+            hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]
+        )
+
+        # Target token embedding
+        self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+        self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+        # Transformer decoder
+        self.decoder = Decoder(
+            decoder_layer=DecoderLayer(
+                hidden_dim=hidden_dim,
+                num_heads=num_heads,
+                expansion_dim=expansion_dim,
+                dropout_rate=dropout_rate,
+                activation=transformer_activation,
+            ),
+            num_layers=num_decoder_layers,
+            norm=nn.LayerNorm(hidden_dim),
+        )
+
+        # Classification head
+        self.head = nn.Linear(hidden_dim, self.vocab_size)
+
+        # Initialize weights
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        """Initialize network weights."""
+        self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+        self.head.bias.data.zero_()
+        self.head.weight.data.uniform_(-0.1, 0.1)
+
+        # The 2D positional encoding has no learnable parameters; only
+        # initialize the latent encoding if a learnable projection is used.
+        if hasattr(self.latent_encoding, "weight"):
+            nn.init.kaiming_normal_(
+                self.latent_encoding.weight.data,
+                a=0,
+                mode="fan_out",
+                nonlinearity="relu",
+            )
+            if self.latent_encoding.bias is not None:
+                _, fan_out = nn.init._calculate_fan_in_and_fan_out(
+                    self.latent_encoding.weight.data
+                )
+                bound = 1 / math.sqrt(fan_out)
+                nn.init.normal_(self.latent_encoding.bias, -bound, bound)
+
+    def _configure_mapping(
+        self, mapping: Optional[List[str]]
+    ) -> Tuple[List[str], Dict[str, int]]:
+        """Configures the token mapping and its inverse."""
+        if mapping is None:
+            mapping, inverse_mapping, _ = emnist_mapping()
+        else:
+            inverse_mapping = {token: index for index, token in enumerate(mapping)}
+        return mapping, inverse_mapping
+
+    def encode(self, image: Tensor) -> Tensor:
+        """Extracts image features with the backbone.
+
+        Args:
+            image (Tensor): Image(s) of handwritten text.
+
+        Returns:
+            Tensor: Image features.
+
+        Shapes:
+            - image: :math:`(B, C, H, W)`
+            - latent: :math:`(B, T, C)`
+
+        """
+        # Extract image features.
+        latent = self.backbone(image)
+
+        # Add 2D positional encoding to the feature maps.
+        latent = self.latent_encoding(latent)
+
+        # Collapse the feature map height and width into one sequence dimension.
+        latent = rearrange(latent, "b c h w -> b (h w) c")
+        return latent
+
+    def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+        """Decodes image features with the transformer decoder."""
+        trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
+        trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+        trg = self.trg_position_encoding(trg)
+        out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
+        logits = self.head(out)
+        return logits
+
+    def predict(self, image: Tensor) -> Tensor:
+        """Transcribes text in image(s) with greedy decoding."""
+        bsz = image.shape[0]
+        image_features = self.encode(image)
+
+        output_tokens = (
+            (torch.ones((bsz, self.max_output_length)) * self.pad_index)
+            .type_as(image)
+            .long()
+        )
+        output_tokens[:, 0] = self.start_index
+        for i in range(1, self.max_output_length):
+            trg = output_tokens[:, :i]
+            output = self.decode(image_features, trg)
+            output = torch.argmax(output, dim=-1)
+            # Keep the prediction for the latest position of every sequence.
+            output_tokens[:, i] = output[:, -1]
+
+        # Set all tokens after the end token to padding.
+        for i in range(1, self.max_output_length):
+            indices = (output_tokens[:, i - 1] == self.end_index) | (
+                output_tokens[:, i - 1] == self.pad_index
+            )
+            output_tokens[indices, i] = self.pad_index
+
+        return output_tokens
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 9febc88..139cd23 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
 """Transformer modules."""
-from .positional_encoding import PositionalEncoding
+from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask
 from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index cce1ecc..ac75d2f 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module):
         )
         nn.init.xavier_normal_(self.fc_out.weight)
 
+    @staticmethod
     def scaled_dot_product_attention(
-        self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+        query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
     ) -> Tensor:
         """Calculates the scaled dot product attention."""
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d67d297..dbde887 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -56,9 +56,9 @@ class PositionalEncoding2D(nn.Module):
         pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
 
         pe_w = PositionalEncoding.make_pe(
-            hidden_dim // 2, max_len=max_h
+            hidden_dim // 2, max_len=max_w
         )  # [W, 1, D // 2]
-        pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h)
+        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
 
         pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
         return pe
@@ -70,3 +70,13 @@ class PositionalEncoding2D(nn.Module):
             raise ValueError("Hidden dimensions does not match.")
         x += self.pe[:, : x.shape[2], : x.shape[3]]
         return x
+
+def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
+    """Returns a combined padding and causal target mask."""
+    trg_pad_mask = (trg != pad_index)[:, None, None]
+    trg_len = trg.shape[1]
+    trg_sub_mask = torch.tril(
+        torch.ones((trg_len, trg_len), device=trg.device)
+    ).bool()
+    trg_mask = trg_pad_mask & trg_sub_mask
+    return trg_mask
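The `positional_encoding.py` hunk builds the width half of the 2D encoding from `max_w` instead of `max_h` and adjusts the `repeat` pattern to match. The check below sketches why the corrected version tiles to the expected `[D, H, W]` shape; `make_pe` is stubbed here with the `(max_len, 1, hidden_dim)` output shape that the inline comments describe, so the real sinusoidal values are not reproduced.

```python
# Shape check for the corrected 2D positional encoding; make_pe is a stand-in
# with the (max_len, 1, hidden_dim) shape noted in the diff's comments.
import torch
from einops import repeat


def make_pe(hidden_dim: int, max_len: int) -> torch.Tensor:
    return torch.zeros(max_len, 1, hidden_dim)  # stub for the sinusoidal table


hidden_dim, max_h, max_w = 256, 18, 20

pe_h = make_pe(hidden_dim // 2, max_len=max_h)            # [H, 1, D // 2]
pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)  # [D // 2, H, W]

pe_w = make_pe(hidden_dim // 2, max_len=max_w)            # [W, 1, D // 2]
pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)  # [D // 2, H, W]

pe = torch.cat([pe_h, pe_w], dim=0)
print(pe.shape)  # torch.Size([256, 18, 20]), i.e. [D, H, W]
```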
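The new `target_padding_mask` combines a padding mask over the target tokens with a lower-triangular causal mask. A small standalone sketch of the same computation follows; the toy token values and pad index are assumptions for illustration.

```python
# Toy reproduction of the target mask logic; token values and pad index are
# made up for the example.
import torch

trg = torch.tensor([[1, 5, 7, 0, 0]])  # one sequence, 0 acts as the pad index
pad_index = 0

trg_pad_mask = (trg != pad_index)[:, None, None]  # (B, 1, 1, Sy)
trg_len = trg.shape[1]
trg_sub_mask = torch.tril(
    torch.ones((trg_len, trg_len), device=trg.device)
).bool()                                          # (Sy, Sy) causal mask
trg_mask = trg_pad_mask & trg_sub_mask            # broadcasts to (B, 1, Sy, Sy)

print(trg_mask.shape)  # torch.Size([1, 1, 5, 5])
# Position i may attend to positions <= i, and padded positions are masked
# out as keys in every row.
```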
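For reference, a minimal usage sketch of the new `ImageTransformer` is shown below. The backbone, input/output shapes, and token mapping are assumptions invented for the example (any CNN emitting `hidden_dim` feature channels would do), and it presumes the `Decoder`, `DecoderLayer`, and positional encoding modules from `text_recognizer.networks.transformer` behave as used in this commit.

```python
# Hypothetical usage sketch -- backbone, shapes, and mapping are assumptions,
# not part of the commit.
import torch
from torch import nn

from text_recognizer.networks.image_transformer import ImageTransformer

# A toy token mapping; only "<s>", "<e>", and "<p>" are required by the network.
mapping = ["<s>", "<e>", "<p>"] + list("abcdefghijklmnopqrstuvwxyz ")

# Stand-in backbone: any CNN mapping (B, 1, H, W) -> (B, hidden_dim, H', W').
backbone = nn.Sequential(
    nn.Conv2d(1, 256, kernel_size=3, stride=2, padding=1),
    nn.ReLU(),
    nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
)

network = ImageTransformer(
    input_shape=(1, 64, 256),  # (C, H, W), assumed
    output_shape=(32,),        # maximum output sequence length, assumed
    backbone=backbone,
    mapping=mapping,
)

images = torch.rand(2, 1, 64, 256)
tokens = network.predict(images)  # (2, 32) greedily decoded token indices
```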