summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/image_transformer.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 23:08:16 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-04 23:08:16 +0200
commit2d4714fcfeb8914f240a0d36d938b434e82f191b (patch)
tree32e7b3446332cee4685ec90870210c51f9f1279f /text_recognizer/networks/image_transformer.py
parent5dc8a7097ab6b4f39f0a3add408e3fd0f131f85b (diff)
Add new transformer network
Diffstat (limited to 'text_recognizer/networks/image_transformer.py')
-rw-r--r--text_recognizer/networks/image_transformer.py159
1 files changed, 159 insertions, 0 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
new file mode 100644
index 0000000..5a093dc
--- /dev/null
+++ b/text_recognizer/networks/image_transformer.py
@@ -0,0 +1,159 @@
+"""A Transformer with a cnn backbone.
+
+The network encodes a image with a convolutional backbone to a latent representation,
+i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+spatial information. The resulting feature are then set to a transformer decoder
+together with the target tokens.
+
+TODO: Local attention for transformer.j
+
+"""
+import math
+from typing import Any, Dict, List, Optional, Sequence, Type
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torchvision
+
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.networks.transformer import (
+ Decoder,
+ DecoderLayer,
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
+
+
+class ImageTransformer(nn.Module):
+ def __init__(
+ self,
+ input_shape: Sequence[int],
+ output_shape: Sequence[int],
+ backbone: Type[nn.Module],
+ mapping: Optional[List[str]] = None,
+ num_decoder_layers: int = 4,
+ hidden_dim: int = 256,
+ num_heads: int = 4,
+ expansion_dim: int = 4,
+ dropout_rate: float = 0.1,
+ transformer_activation: str = "glu",
+ ) -> None:
+ # Configure mapping
+ mapping, inverse_mapping = self._configure_mapping(mapping)
+ self.vocab_size = len(mapping)
+ self.hidden_dim = hidden_dim
+ self.max_output_length = output_shape[0]
+ self.start_index = inverse_mapping["<s>"]
+ self.end_index = inverse_mapping["<e>"]
+ self.pad_index = inverse_mapping["<p>"]
+
+ # Image backbone
+ self.backbone = backbone
+ self.latent_encoding = PositionalEncoding2D(hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2])
+
+ # Target token embedding
+ self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+ # Transformer decoder
+ self.decoder = Decoder(
+ decoder_layer=DecoderLayer(
+ hidden_dim=hidden_dim,
+ num_heads=num_heads,
+ expansion_dim=expansion_dim,
+ dropout_rate=dropout_rate,
+ activation=transformer_activation,
+ ),
+ num_layers=num_decoder_layers,
+ norm=nn.LayerNorm(hidden_dim),
+ )
+
+ # Classification head
+ self.head = nn.Linear(hidden_dim, self.vocab_size)
+
+ # Initialize weights
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Initialize network weights."""
+ self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+ self.head.bias.data.zero_()
+ self.head.weight.data.uniform_(-0.1, 0.1)
+
+ nn.init.kaiming_normal_(self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu")
+ if self.latent_encoding.bias is not None:
+ _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.latent_encoding.weight.data)
+ bound = 1 / math.sqrt(fan_out)
+ nn.init.normal_(self.latent_encoding.bias, -bound, bound)
+
+ def _configure_mapping(self, mapping: Optional[List[str]]) -> Tuple[List[str], Dict[str, int]]:
+ """Configures mapping."""
+ if mapping is None:
+ mapping, inverse_mapping, _ = emnist_mapping()
+ return mapping, inverse_mapping
+
+ def encode(self, image: Tensor) -> Tensor:
+ """Extracts image features with backbone.
+
+ Args:
+ image (Tensor): Image(s) of handwritten text.
+
+ Retuns:
+ Tensor: Image features.
+
+ Shapes:
+ - image: :math: `(B, C, H, W)`
+ - latent: :math: `(B, T, C)`
+
+ """
+ # Extract image features.
+ latent = self.backbone(image)
+
+ # Add 2d encoding to the feature maps.
+ latent = self.latent_encoding(latent)
+
+ # Collapse features maps height and width.
+ latent = rearrange(latent, "b c h w -> b (h w) c")
+ return latent
+
+ def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+ """Decodes image features with transformer decoder."""
+ trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
+ trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+ trg = self.trg_position_encoding(trg)
+ out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
+ logits = self.head(out)
+ return logits
+
+ def predict(self, image: Tensor) -> Tensor:
+ """Transcribes text in image(s)."""
+ bsz = image.shape[0]
+ image_features = self.encode(image)
+
+ output_tokens = (torch.ones((bsz, self.max_output_length)) * self.pad_index).type_as(image).long()
+ output_tokens[:, 0] = self.start_index
+ for i in range(1, self.max_output_length):
+ trg = output_tokens[:, :i]
+ output = self.decode(image_features, trg)
+ output = torch.argmax(output, dim=-1)
+ output_tokens[:, i] = output[-1:]
+
+ # Set all tokens after end token to be padding.
+ for i in range(1, self.max_output_length):
+ indices = (output_tokens[:, i - 1] == self.end_index | (output_tokens[:, i - 1] == self.pad_index))
+ output_tokens[indices, i] = self.pad_index
+
+ return output_tokens
+
+
+
+
+
+
+
+
+
+