From 2d4714fcfeb8914f240a0d36d938b434e82f191b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 4 Apr 2021 23:08:16 +0200 Subject: Add new transformer network --- notebooks/03-look-at-iam-paragraphs.ipynb | 16 +-- text_recognizer/data/emnist.py | 4 +- text_recognizer/data/emnist_essentials.json | 2 +- text_recognizer/networks/image_transformer.py | 159 +++++++++++++++++++++ text_recognizer/networks/transformer/__init__.py | 2 +- text_recognizer/networks/transformer/attention.py | 3 +- .../networks/transformer/positional_encoding.py | 14 +- 7 files changed, 185 insertions(+), 15 deletions(-) create mode 100644 text_recognizer/networks/image_transformer.py diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 73045c6..df92f99 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -3,7 +3,7 @@ { "cell_type": "code", "execution_count": 7, - "id": "a6f19997", + "id": "6ce2519f", "metadata": {}, "outputs": [ { @@ -40,7 +40,7 @@ { "cell_type": "code", "execution_count": 2, - "id": "abe7e727", + "id": "726ac25b", "metadata": {}, "outputs": [], "source": [ @@ -57,7 +57,7 @@ { "cell_type": "code", "execution_count": 3, - "id": "10519f10", + "id": "42501428", "metadata": {}, "outputs": [ { @@ -94,7 +94,7 @@ { "cell_type": "code", "execution_count": 4, - "id": "2672fb27", + "id": "e7778ae2", "metadata": { "scrolled": false }, @@ -172,7 +172,7 @@ { "cell_type": "code", "execution_count": 5, - "id": "8b9ef38c", + "id": "9d11ca56", "metadata": { "scrolled": false }, @@ -251,7 +251,7 @@ { "cell_type": "code", "execution_count": 8, - "id": "09b91f61", + "id": "548d10da", "metadata": {}, "outputs": [ { @@ -286,7 +286,7 @@ { "cell_type": "code", "execution_count": 9, - "id": "c883fa43", + "id": "627730b5", "metadata": { "scrolled": false }, @@ -364,7 +364,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6703bfaf", + "id": "25a074df", "metadata": {}, "outputs": [], "source": [] diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index eda490a..12adaab 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -96,7 +96,7 @@ class EMNIST(BaseDataModule): def emnist_mapping( - extra_symbols: Optional[List[str]], + extra_symbols: Optional[Sequence[str]], ) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" if not ESSENTIALS_FILENAME.exists(): @@ -209,7 +209,7 @@ def _augment_emnist_characters(characters: Sequence[str]) -> Sequence[str]: # - End token at index 2 # - Padding token at index 3 # Note: Do not forget to update NUM_SPECIAL_TOKENS if changing this! - return ["", "", "", "

", *characters, *iam_characters] + return ["", "", "", "

", *characters, *iam_characters] def download_emnist() -> None: diff --git a/text_recognizer/data/emnist_essentials.json b/text_recognizer/data/emnist_essentials.json index 3f46a73..956c28d 100644 --- a/text_recognizer/data/emnist_essentials.json +++ b/text_recognizer/data/emnist_essentials.json @@ -1 +1 @@ -{"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} +{"characters": ["", "", "", "

", "0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z", "a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z", " ", "!", "\"", "#", "&", "'", "(", ")", "*", "+", ",", "-", ".", "/", ":", ";", "?"], "input_shape": [28, 28]} diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py new file mode 100644 index 0000000..5a093dc --- /dev/null +++ b/text_recognizer/networks/image_transformer.py @@ -0,0 +1,159 @@ +"""A Transformer with a cnn backbone. + +The network encodes a image with a convolutional backbone to a latent representation, +i.e. feature maps. A 2d positional encoding is applied to the feature maps for +spatial information. The resulting feature are then set to a transformer decoder +together with the target tokens. + +TODO: Local attention for transformer.j + +""" +import math +from typing import Any, Dict, List, Optional, Sequence, Type + +from einops import rearrange +import torch +from torch import nn +from torch import Tensor +import torchvision + +from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.networks.transformer import ( + Decoder, + DecoderLayer, + PositionalEncoding, + PositionalEncoding2D, + target_padding_mask, +) + + +class ImageTransformer(nn.Module): + def __init__( + self, + input_shape: Sequence[int], + output_shape: Sequence[int], + backbone: Type[nn.Module], + mapping: Optional[List[str]] = None, + num_decoder_layers: int = 4, + hidden_dim: int = 256, + num_heads: int = 4, + expansion_dim: int = 4, + dropout_rate: float = 0.1, + transformer_activation: str = "glu", + ) -> None: + # Configure mapping + mapping, inverse_mapping = self._configure_mapping(mapping) + self.vocab_size = len(mapping) + self.hidden_dim = hidden_dim + self.max_output_length = output_shape[0] + self.start_index = inverse_mapping[""] + self.end_index = inverse_mapping[""] + self.pad_index = inverse_mapping["

"] + + # Image backbone + self.backbone = backbone + self.latent_encoding = PositionalEncoding2D(hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2]) + + # Target token embedding + self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim) + self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate) + + # Transformer decoder + self.decoder = Decoder( + decoder_layer=DecoderLayer( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion_dim=expansion_dim, + dropout_rate=dropout_rate, + activation=transformer_activation, + ), + num_layers=num_decoder_layers, + norm=nn.LayerNorm(hidden_dim), + ) + + # Classification head + self.head = nn.Linear(hidden_dim, self.vocab_size) + + # Initialize weights + self._init_weights() + + def _init_weights(self) -> None: + """Initialize network weights.""" + self.trg_embedding.weight.data.uniform_(-0.1, 0.1) + self.head.bias.data.zero_() + self.head.weight.data.uniform_(-0.1, 0.1) + + nn.init.kaiming_normal_(self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu") + if self.latent_encoding.bias is not None: + _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.latent_encoding.weight.data) + bound = 1 / math.sqrt(fan_out) + nn.init.normal_(self.latent_encoding.bias, -bound, bound) + + def _configure_mapping(self, mapping: Optional[List[str]]) -> Tuple[List[str], Dict[str, int]]: + """Configures mapping.""" + if mapping is None: + mapping, inverse_mapping, _ = emnist_mapping() + return mapping, inverse_mapping + + def encode(self, image: Tensor) -> Tensor: + """Extracts image features with backbone. + + Args: + image (Tensor): Image(s) of handwritten text. + + Retuns: + Tensor: Image features. + + Shapes: + - image: :math: `(B, C, H, W)` + - latent: :math: `(B, T, C)` + + """ + # Extract image features. + latent = self.backbone(image) + + # Add 2d encoding to the feature maps. + latent = self.latent_encoding(latent) + + # Collapse features maps height and width. + latent = rearrange(latent, "b c h w -> b (h w) c") + return latent + + def decode(self, memory: Tensor, trg: Tensor) -> Tensor: + """Decodes image features with transformer decoder.""" + trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index) + trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim) + trg = self.trg_position_encoding(trg) + out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None) + logits = self.head(out) + return logits + + def predict(self, image: Tensor) -> Tensor: + """Transcribes text in image(s).""" + bsz = image.shape[0] + image_features = self.encode(image) + + output_tokens = (torch.ones((bsz, self.max_output_length)) * self.pad_index).type_as(image).long() + output_tokens[:, 0] = self.start_index + for i in range(1, self.max_output_length): + trg = output_tokens[:, :i] + output = self.decode(image_features, trg) + output = torch.argmax(output, dim=-1) + output_tokens[:, i] = output[-1:] + + # Set all tokens after end token to be padding. + for i in range(1, self.max_output_length): + indices = (output_tokens[:, i - 1] == self.end_index | (output_tokens[:, i - 1] == self.pad_index)) + output_tokens[indices, i] = self.pad_index + + return output_tokens + + + + + + + + + + diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 9febc88..139cd23 100644 --- a/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py @@ -1,3 +1,3 @@ """Transformer modules.""" -from .positional_encoding import PositionalEncoding +from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask from .transformer import Decoder, Encoder, EncoderLayer, Transformer diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index cce1ecc..ac75d2f 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module): ) nn.init.xavier_normal_(self.fc_out.weight) + @staticmethod def scaled_dot_product_attention( - self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None + query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None ) -> Tensor: """Calculates the scaled dot product attention.""" diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index d67d297..dbde887 100644 --- a/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py @@ -56,9 +56,9 @@ class PositionalEncoding2D(nn.Module): pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w) pe_w = PositionalEncoding.make_pe( - hidden_dim // 2, max_len=max_h + hidden_dim // 2, max_len=max_w ) # [W, 1, D // 2] - pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h) + pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h) pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W] return pe @@ -70,3 +70,13 @@ class PositionalEncoding2D(nn.Module): raise ValueError("Hidden dimensions does not match.") x += self.pe[:, : x.shape[2], : x.shape[3]] return x + +def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor: + """Returns causal target mask.""" + trg_pad_mask = (trg != pad_index)[:, None, None] + trg_len = trg.shape[1] + trg_sub_mask = torch.tril( + torch.ones((trg_len, trg_len), device=trg.device) + ).bool() + trg_mask = trg_pad_mask & trg_sub_mask + return trg_mask -- cgit v1.2.3-70-g09d2