summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/transformer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-05-09 22:46:09 +0200
commitc9c60678673e19ad3367339eb8e7a093e5a98474 (patch)
treeb787a7fbb535c2ee44f935720d75034cc24ffd30 /text_recognizer/networks/transformer
parenta2a3133ed5da283888efbdb9924d0e3733c274c8 (diff)
Reformatting of positional encodings and ViT working
Diffstat (limited to 'text_recognizer/networks/transformer')
-rw-r--r--text_recognizer/networks/transformer/__init__.py7
-rw-r--r--text_recognizer/networks/transformer/attention.py4
-rw-r--r--text_recognizer/networks/transformer/layers.py29
-rw-r--r--text_recognizer/networks/transformer/nystromer/attention.py4
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/__init__.py4
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py16
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/positional_encoding.py (renamed from text_recognizer/networks/transformer/positional_encoding.py)0
-rw-r--r--text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py (renamed from text_recognizer/networks/transformer/rotary_embedding.py)0
-rw-r--r--text_recognizer/networks/transformer/transformer.py321
-rw-r--r--text_recognizer/networks/transformer/vit.py46
10 files changed, 153 insertions, 278 deletions
diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 4ff48f7..a3f3011 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1,8 +1 @@
"""Transformer modules."""
-from .positional_encoding import (
- PositionalEncoding,
- PositionalEncoding2D,
- target_padding_mask,
-)
-
-# from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 623d680..eabeadf 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -9,7 +9,9 @@ from torch import nn
from torch import Tensor
import torch.nn.functional as F
-from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb
+from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import (
+ apply_rotary_pos_emb,
+)
class Attention(nn.Module):
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index a2fdb1a..4063425 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,6 +1,6 @@
"""Generates the attention layer architecture."""
from functools import partial
-from typing import Dict, Optional, Type
+from typing import Any, Dict, Optional, Type
from click.types import Tuple
@@ -36,12 +36,11 @@ class AttentionLayers(nn.Module):
norm_fn = partial(norm_fn, dim=dim)
ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
self.layer_types = self._get_layer_types(cross_attend) * depth
- self.layers = self._build_network(
- causal, attn_fn, norm_fn, ff_fn, residual_fn
- )
+ self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn)
rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
self.pre_norm = pre_norm
+ self.has_pos_emb = True if self.rotary_emb is not None else False
@staticmethod
def _get_layer_types(cross_attend: bool) -> Tuple:
@@ -70,7 +69,7 @@ class AttentionLayers(nn.Module):
residual_fn = residual_fn()
- layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+ layers.append(nn.modulelist([norm_fn(), layer, residual_fn]))
return layers
def forward(
@@ -82,10 +81,12 @@ class AttentionLayers(nn.Module):
) -> Tensor:
rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
- for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+ for i, (layer_type, (norm, block, residual_fn)) in enumerate(
+ zip(self.layer_types, self.layers)
+ ):
is_last = i == len(self.layers) - 1
-
- residual = x
+
+ residual = x
if self.pre_norm:
x = norm(x)
@@ -103,3 +104,15 @@ class AttentionLayers(nn.Module):
x = norm(x)
return x
+
+
+class Encoder(AttentionLayers):
+ def __init__(self, **kwargs: Any) -> None:
+ assert "causal" not in kwargs, "Cannot set causality on encoder"
+ super().__init__(causal=False, **kwargs)
+
+
+class Decoder(AttentionLayers):
+ def __init__(self, **kwargs: Any) -> None:
+ assert "causal" not in kwargs, "Cannot set causality on decoder"
+ super().__init__(causal=True, **kwargs)
diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py
index c2871fb..5ab19cf 100644
--- a/text_recognizer/networks/transformer/nystromer/attention.py
+++ b/text_recognizer/networks/transformer/nystromer/attention.py
@@ -157,14 +157,14 @@ class NystromAttention(nn.Module):
self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False
) -> Union[Tensor, Tuple[Tensor, Tensor]]:
"""Compute the Nystrom attention."""
- _, n, _, h, m = x.shape, self.num_heads
+ _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks
if n % m != 0:
x, mask = self._pad_sequence(x, mask, n, m)
q, k, v = self.qkv_fn(x).chunk(3, dim=-1)
q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
- q *= self.scale
+ q = q * self.scale
out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn)
diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py
new file mode 100644
index 0000000..91278ee
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encodings/__init__.py
@@ -0,0 +1,4 @@
+"""Positional encoding for transformers."""
+from .absolute_embedding import AbsolutePositionalEmbedding
+from .positional_encoding import PositionalEncoding, PositionalEncoding2D
+from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding
diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
new file mode 100644
index 0000000..9466f6e
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
@@ -0,0 +1,16 @@
+"""Absolute positional embedding."""
+from torch import nn, Tensor
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim: int, max_seq_len: int) -> None:
+ super().__init__()
+ self.emb = nn.Embedding(max_seq_len, dim)
+ self._weight_init()
+
+ def _weight_init(self) -> None:
+ nn.init.normal_(self.emb.weight, std=0.02)
+
+ def forward(self, x: Tensor) -> Tensor:
+ n = torch.arange(x.shape[1], device=x.device)
+ return self.emb(n)[None, :, :]
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
index c50afc3..c50afc3 100644
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
index 5e80572..5e80572 100644
--- a/text_recognizer/networks/transformer/rotary_embedding.py
+++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index d49c85a..36f86ac 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -1,260 +1,61 @@
-# """Transfomer module."""
-# import copy
-# from typing import Dict, Optional, Type, Union
-#
-# import numpy as np
-# import torch
-# from torch import nn
-# from torch import Tensor
-# import torch.nn.functional as F
-#
-# from text_recognizer.networks.transformer.attention import MultiHeadAttention
-# from text_recognizer.networks.util import activation_function
-#
-#
-# class GEGLU(nn.Module):
-# """GLU activation for improving feedforward activations."""
-#
-# def __init__(self, dim_in: int, dim_out: int) -> None:
-# super().__init__()
-# self.proj = nn.Linear(dim_in, dim_out * 2)
-#
-# def forward(self, x: Tensor) -> Tensor:
-# """Forward propagation."""
-# x, gate = self.proj(x).chunk(2, dim=-1)
-# return x * F.gelu(gate)
-#
-#
-# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
-# return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
-#
-#
-# class _IntraLayerConnection(nn.Module):
-# """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
-#
-# def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
-# super().__init__()
-# self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
-# self.dropout = nn.Dropout(p=dropout_rate)
-#
-# def forward(self, src: Tensor, residual: Tensor) -> Tensor:
-# return self.norm(self.dropout(src) + residual)
-#
-#
-# class FeedForward(nn.Module):
-# def __init__(
-# self,
-# hidden_dim: int,
-# expansion_dim: int,
-# dropout_rate: float,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-#
-# in_projection = (
-# nn.Sequential(
-# nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
-# )
-# if activation != "glu"
-# else GEGLU(hidden_dim, expansion_dim)
-# )
-#
-# self.layer = nn.Sequential(
-# in_projection,
-# nn.Dropout(p=dropout_rate),
-# nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
-# )
-#
-# def forward(self, x: Tensor) -> Tensor:
-# return self.layer(x)
-#
-#
-# class EncoderLayer(nn.Module):
-# """Transfomer encoding layer."""
-#
-# def __init__(
-# self,
-# hidden_dim: int,
-# num_heads: int,
-# expansion_dim: int,
-# dropout_rate: float,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
-# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
-# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#
-# def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
-# """Forward pass through the encoder."""
-# # First block.
-# # Multi head attention.
-# out, _ = self.self_attention(src, src, src, mask)
-#
-# # Add & norm.
-# out = self.block1(out, src)
-#
-# # Second block.
-# # Apply 1D-convolution.
-# mlp_out = self.mlp(out)
-#
-# # Add & norm.
-# out = self.block2(mlp_out, out)
-#
-# return out
-#
-#
-# class Encoder(nn.Module):
-# """Transfomer encoder module."""
-#
-# def __init__(
-# self,
-# num_layers: int,
-# encoder_layer: Type[nn.Module],
-# norm: Optional[Type[nn.Module]] = None,
-# ) -> None:
-# super().__init__()
-# self.layers = _get_clones(encoder_layer, num_layers)
-# self.norm = norm
-#
-# def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
-# """Forward pass through all encoder layers."""
-# for layer in self.layers:
-# src = layer(src, src_mask)
-#
-# if self.norm is not None:
-# src = self.norm(src)
-#
-# return src
-#
-#
-# class DecoderLayer(nn.Module):
-# """Transfomer decoder layer."""
-#
-# def __init__(
-# self,
-# hidden_dim: int,
-# num_heads: int,
-# expansion_dim: int,
-# dropout_rate: float = 0.0,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-# self.hidden_dim = hidden_dim
-# self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-# self.multihead_attention = MultiHeadAttention(
-# hidden_dim, num_heads, dropout_rate
-# )
-# self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
-# self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
-# self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-# self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#
-# def forward(
-# self,
-# trg: Tensor,
-# memory: Tensor,
-# trg_mask: Optional[Tensor] = None,
-# memory_mask: Optional[Tensor] = None,
-# ) -> Tensor:
-# """Forward pass of the layer."""
-# out, _ = self.self_attention(trg, trg, trg, trg_mask)
-# trg = self.block1(out, trg)
-#
-# out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
-# trg = self.block2(out, trg)
-#
-# out = self.mlp(trg)
-# out = self.block3(out, trg)
-#
-# return out
-#
-#
-# class Decoder(nn.Module):
-# """Transfomer decoder module."""
-#
-# def __init__(
-# self,
-# decoder_layer: Type[nn.Module],
-# num_layers: int,
-# norm: Optional[Type[nn.Module]] = None,
-# ) -> None:
-# super().__init__()
-# self.layers = _get_clones(decoder_layer, num_layers)
-# self.num_layers = num_layers
-# self.norm = norm
-#
-# def forward(
-# self,
-# trg: Tensor,
-# memory: Tensor,
-# trg_mask: Optional[Tensor] = None,
-# memory_mask: Optional[Tensor] = None,
-# ) -> Tensor:
-# """Forward pass through the decoder."""
-# for layer in self.layers:
-# trg = layer(trg, memory, trg_mask, memory_mask)
-#
-# if self.norm is not None:
-# trg = self.norm(trg)
-#
-# return trg
-#
-#
-# class Transformer(nn.Module):
-# """Transformer network."""
-#
-# def __init__(
-# self,
-# num_encoder_layers: int,
-# num_decoder_layers: int,
-# hidden_dim: int,
-# num_heads: int,
-# expansion_dim: int,
-# dropout_rate: float,
-# activation: str = "relu",
-# ) -> None:
-# super().__init__()
-#
-# # Configure encoder.
-# encoder_norm = nn.LayerNorm(hidden_dim)
-# encoder_layer = EncoderLayer(
-# hidden_dim, num_heads, expansion_dim, dropout_rate, activation
-# )
-# self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
-#
-# # Configure decoder.
-# decoder_norm = nn.LayerNorm(hidden_dim)
-# decoder_layer = DecoderLayer(
-# hidden_dim, num_heads, expansion_dim, dropout_rate, activation
-# )
-# self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
-#
-# self._reset_parameters()
-#
-# def _reset_parameters(self) -> None:
-# for p in self.parameters():
-# if p.dim() > 1:
-# nn.init.xavier_uniform_(p)
-#
-# def forward(
-# self,
-# src: Tensor,
-# trg: Tensor,
-# src_mask: Optional[Tensor] = None,
-# trg_mask: Optional[Tensor] = None,
-# memory_mask: Optional[Tensor] = None,
-# ) -> Tensor:
-# """Forward pass through the transformer."""
-# if src.shape[0] != trg.shape[0]:
-# print(trg.shape)
-# raise RuntimeError("The batch size of the src and trg must be the same.")
-# if src.shape[2] != trg.shape[2]:
-# raise RuntimeError(
-# "The number of features for the src and trg must be the same."
-# )
-#
-# memory = self.encoder(src, src_mask)
-# output = self.decoder(trg, memory, trg_mask, memory_mask)
-# return output
+"""Transformer wrapper."""
+from typing import Optional, Type
+
+from torch import nn, Tensor
+
+from .layers import AttentionLayers
+from text_recognizer.networks.transformer.positional_encodings import (
+ AbsolutePositionalEmbedding,
+)
+
+
+class Transformer(nn.Module):
+ def __init__(
+ self,
+ num_tokens: int,
+ max_seq_len: int,
+ attn_layers: Type[AttentionLayers],
+ emb_dim: Optional[int] = None,
+ emb_dropout: float = 0.0,
+ use_pos_emb: bool = True,
+ ) -> None:
+ dim = attn_layers.dim
+ emb_dim = emb_dim if emb_dim is not None else dim
+ self.max_seq_len = max_seq_len
+
+ self.token_emb = nn.Embedding(num_tokens, emb_dim)
+ self.emb_dropout = nn.Dropout(emb_dropout)
+ self.pos_emb = (
+ AbsolutePositionalEmbedding(emb_dim, max_seq_len)
+ if (use_pos_emb and not self.attn_layers.has_pos_emb)
+ else None
+ )
+
+ self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+ self.attn_layers = attn_layers
+ self.norm = nn.LayerNorm(dim)
+
+ self._init_weights()
+
+ self.logits = nn.Linear(dim, num_tokens)
+
+ def _init_weights(self) -> None:
+ nn.init.normal_(self.token_emb.weight, std=0.02)
+
+ def forward(
+ self,
+ x: Tensor,
+ mask: Optional[Tensor],
+ return_embeddings: bool = False,
+ **kwargs: Any
+ ) -> Tensor:
+ b, n, device = *x.shape, x.device
+ x += self.token_emb(x)
+ if self.pos_emb is not None:
+ x += self.pos_emb(x)
+ x = self.emb_dropout(x)
+
+ x = self.project_emb(x)
+ x = self.attn_layers(x, mask=mask, **kwargs)
+ out = self.logits(x) if not return_embeddings else x
+ return x
diff --git a/text_recognizer/networks/transformer/vit.py b/text_recognizer/networks/transformer/vit.py
index e69de29..ab331f8 100644
--- a/text_recognizer/networks/transformer/vit.py
+++ b/text_recognizer/networks/transformer/vit.py
@@ -0,0 +1,46 @@
+"""Vision Transformer."""
+from typing import Tuple, Type
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn, Tensor
+
+
+class ViT(nn.Module):
+ def __init__(
+ self,
+ image_size: Tuple[int, int],
+ patch_size: Tuple[int, int],
+ dim: int,
+ transformer: Type[nn.Module],
+ channels: int = 1,
+ ) -> None:
+ super().__init__()
+ img_height, img_width = image_size
+ patch_height, patch_width = patch_size
+ assert img_height % patch_height == 0
+ assert img_width % patch_width == 0
+
+ num_patches = (img_height // patch_height) * (img_width // patch_width)
+ patch_dim = channels * patch_height * patch_width
+
+ self.to_patch_embedding = nn.Sequential(
+ Rearrange(
+ "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+ p1=patch_height,
+ p2=patch_width,
+ c=channels,
+ ),
+ nn.Linear(patch_dim, dim),
+ )
+
+ self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
+ self.transformer = transformer
+ self.norm = nn.LayerNorm(dim)
+
+ def forward(self, img: Tensor) -> Tensor:
+ x = self.to_patch_embedding(img)
+ _, n, _ = x.shape
+ x += self.pos_embedding[:, :n]
+ x = self.transformer(x)
+ return x