Diffstat (limited to 'text_recognizer/networks')
-rw-r--r--  text_recognizer/networks/image_transformer.py                159
-rw-r--r--  text_recognizer/networks/transformer/__init__.py               2
-rw-r--r--  text_recognizer/networks/transformer/attention.py              3
-rw-r--r--  text_recognizer/networks/transformer/positional_encoding.py   14
4 files changed, 174 insertions, 4 deletions
diff --git a/text_recognizer/networks/image_transformer.py b/text_recognizer/networks/image_transformer.py
new file mode 100644
index 0000000..5a093dc
--- /dev/null
+++ b/text_recognizer/networks/image_transformer.py
@@ -0,0 +1,159 @@
+"""A Transformer with a cnn backbone.
+
+The network encodes a image with a convolutional backbone to a latent representation,
+i.e. feature maps. A 2d positional encoding is applied to the feature maps for
+spatial information. The resulting feature are then set to a transformer decoder
+together with the target tokens.
+
+TODO: Local attention for the transformer.
+
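+Example (an illustrative sketch only; assumes a backbone whose feature maps have
+``hidden_dim`` channels, here a ResNet-18 truncated before its pooling and
+classification head, and the default EMNIST mapping with <s>, <e> and <p> tokens):
+
+    import torch
+    import torchvision
+
+    backbone = torch.nn.Sequential(*list(torchvision.models.resnet18().children())[:-2])
+    network = ImageTransformer(
+        input_shape=(3, 224, 224),
+        output_shape=(97, 1),
+        backbone=backbone,
+        hidden_dim=512,  # must match the backbone's output channels
+    )
+    tokens = network.predict(torch.randn(1, 3, 224, 224))  # (1, 97) token indices
+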
+"""
+import math
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+import torchvision
+
+from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.networks.transformer import (
+ PositionalEncoding,
+ PositionalEncoding2D,
+ target_padding_mask,
+)
+
+# DecoderLayer is not re-exported by the package __init__, so import it together
+# with Decoder from the transformer submodule directly.
+from text_recognizer.networks.transformer.transformer import Decoder, DecoderLayer
+
+
+class ImageTransformer(nn.Module):
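+ """Image-to-text transcription network: CNN backbone encoder and transformer decoder."""
+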
+ def __init__(
+ self,
+ input_shape: Sequence[int],
+ output_shape: Sequence[int],
+ backbone: nn.Module,
+ mapping: Optional[List[str]] = None,
+ num_decoder_layers: int = 4,
+ hidden_dim: int = 256,
+ num_heads: int = 4,
+ expansion_dim: int = 4,
+ dropout_rate: float = 0.1,
+ transformer_activation: str = "glu",
+ ) -> None:
+ super().__init__()
+
+ # Configure mapping
+ mapping, inverse_mapping = self._configure_mapping(mapping)
+ self.vocab_size = len(mapping)
+ self.hidden_dim = hidden_dim
+ self.max_output_length = output_shape[0]
+ self.start_index = inverse_mapping["<s>"]
+ self.end_index = inverse_mapping["<e>"]
+ self.pad_index = inverse_mapping["<p>"]
+
+ # Image backbone
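+ # The backbone is expected to output feature maps with hidden_dim channels.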
+ self.backbone = backbone
+ self.latent_encoding = PositionalEncoding2D(hidden_dim=hidden_dim, max_h=input_shape[1], max_w=input_shape[2])
+
+ # Target token embedding
+ self.trg_embedding = nn.Embedding(self.vocab_size, hidden_dim)
+ self.trg_position_encoding = PositionalEncoding(hidden_dim, dropout_rate)
+
+ # Transformer decoder
+ self.decoder = Decoder(
+ decoder_layer=DecoderLayer(
+ hidden_dim=hidden_dim,
+ num_heads=num_heads,
+ expansion_dim=expansion_dim,
+ dropout_rate=dropout_rate,
+ activation=transformer_activation,
+ ),
+ num_layers=num_decoder_layers,
+ norm=nn.LayerNorm(hidden_dim),
+ )
+
+ # Classification head
+ self.head = nn.Linear(hidden_dim, self.vocab_size)
+
+ # Initialize weights
+ self._init_weights()
+
+ def _init_weights(self) -> None:
+ """Initialize network weights."""
+ self.trg_embedding.weight.data.uniform_(-0.1, 0.1)
+ self.head.bias.data.zero_()
+ self.head.weight.data.uniform_(-0.1, 0.1)
+
+ # PositionalEncoding2D has no learnable parameters (it registers a buffer),
+ # so only run this initialization if a learnable projection is used instead.
+ if hasattr(self.latent_encoding, "weight"):
+ nn.init.kaiming_normal_(self.latent_encoding.weight.data, a=0, mode="fan_out", nonlinearity="relu")
+ if self.latent_encoding.bias is not None:
+ _, fan_out = nn.init._calculate_fan_in_and_fan_out(self.latent_encoding.weight.data)
+ bound = 1 / math.sqrt(fan_out)
+ nn.init.normal_(self.latent_encoding.bias, -bound, bound)
+
+ def _configure_mapping(self, mapping: Optional[List[str]]) -> Tuple[List[str], Dict[str, int]]:
+ """Configures the token mapping and its inverse, defaulting to the EMNIST mapping."""
+ if mapping is None:
+ mapping, inverse_mapping, _ = emnist_mapping()
+ else:
+ # Build the inverse mapping when a custom mapping is passed in.
+ inverse_mapping = {token: index for index, token in enumerate(mapping)}
+ return mapping, inverse_mapping
+
+ def encode(self, image: Tensor) -> Tensor:
+ """Extracts image features with backbone.
+
+ Args:
+ image (Tensor): Image(s) of handwritten text.
+
+ Returns:
+ Tensor: Image features.
+
+ Shapes:
+ - image: :math:`(B, C, H, W)`
+ - latent: :math:`(B, T, C)`
+
+ """
+ # Extract image features.
+ latent = self.backbone(image)
+
+ # Add 2d encoding to the feature maps.
+ latent = self.latent_encoding(latent)
+
+ # Collapse the feature maps' height and width into a single sequence dimension.
+ latent = rearrange(latent, "b c h w -> b (h w) c")
+ return latent
+
+ def decode(self, memory: Tensor, trg: Tensor) -> Tensor:
+ """Decodes image features with transformer decoder."""
+ trg_mask = target_padding_mask(trg=trg, pad_index=self.pad_index)
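+ # Scale the token embeddings by sqrt(hidden_dim) before adding positional encodings.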
+ trg = self.trg_embedding(trg) * math.sqrt(self.hidden_dim)
+ trg = self.trg_position_encoding(trg)
+ out = self.decoder(trg=trg, memory=memory, trg_mask=trg_mask, memory_mask=None)
+ logits = self.head(out)
+ return logits
+
+ def predict(self, image: Tensor) -> Tensor:
+ """Transcribes text in image(s)."""
+ bsz = image.shape[0]
+ image_features = self.encode(image)
+
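+ # Greedy autoregressive decoding: start every sequence with <s>, fill the rest
+ # with padding, then repeatedly feed the prefix and keep the most likely next
+ # token. type_as(image) places the token tensor on the same device as the input.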
+ output_tokens = (torch.ones((bsz, self.max_output_length)) * self.pad_index).type_as(image).long()
+ output_tokens[:, 0] = self.start_index
+ for i in range(1, self.max_output_length):
+ trg = output_tokens[:, :i]
+ output = self.decode(image_features, trg)
+ output = torch.argmax(output, dim=-1)
+ output_tokens[:, i] = output[:, -1]
+
+ # Set all tokens after end token to be padding.
+ for i in range(1, self.max_output_length):
+ indices = (output_tokens[:, i - 1] == self.end_index) | (output_tokens[:, i - 1] == self.pad_index)
+ output_tokens[indices, i] = self.pad_index
+
+ return output_tokens
+
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 9febc88..139cd23 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1,3 +1,3 @@
"""Transformer modules."""
-from .positional_encoding import PositionalEncoding
+from .positional_encoding import PositionalEncoding, PositionalEncoding2D, target_padding_mask
from .transformer import Decoder, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index cce1ecc..ac75d2f 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -50,8 +50,9 @@ class MultiHeadAttention(nn.Module):
)
nn.init.xavier_normal_(self.fc_out.weight)
+ @staticmethod
def scaled_dot_product_attention(
- self, query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
+ query: Tensor, key: Tensor, value: Tensor, mask: Optional[Tensor] = None
) -> Tensor:
"""Calculates the scaled dot product attention."""
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index d67d297..dbde887 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
@@ -56,9 +56,9 @@ class PositionalEncoding2D(nn.Module):
pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
pe_w = PositionalEncoding.make_pe(
- hidden_dim // 2, max_len=max_h
+ hidden_dim // 2, max_len=max_w
) # [W, 1, D // 2]
- pe_w = repeat(pe_w, "h w d -> d (h tile) w", tile=max_h)
+ pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
pe = torch.cat([pe_h, pe_w], dim=0) # [D, H, W]
return pe
@@ -70,3 +70,13 @@ class PositionalEncoding2D(nn.Module):
raise ValueError("Hidden dimensions does not match.")
x += self.pe[:, : x.shape[2], : x.shape[3]]
return x
+
+def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
+ """Returns causal target mask."""
+ trg_pad_mask = (trg != pad_index)[:, None, None]
+ trg_len = trg.shape[1]
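+ # (T, T) lower-triangular causal mask; combined with the padding mask it
+ # broadcasts to (B, 1, T, T).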
+ trg_sub_mask = torch.tril(
+ torch.ones((trg_len, trg_len), device=trg.device)
+ ).bool()
+ trg_mask = trg_pad_mask & trg_sub_mask
+ return trg_mask