From c9c60678673e19ad3367339eb8e7a093e5a98474 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 9 May 2021 22:46:09 +0200
Subject: Reformat positional encodings and get ViT working

---
 text_recognizer/networks/transformer/__init__.py   |   7 -
 text_recognizer/networks/transformer/attention.py  |   4 +-
 text_recognizer/networks/transformer/layers.py     |  29 +-
 .../networks/transformer/nystromer/attention.py    |   4 +-
 .../networks/transformer/positional_encoding.py    |  85 ------
 .../transformer/positional_encodings/__init__.py   |   4 +
 .../positional_encodings/absolute_embedding.py     |  16 +
 .../positional_encodings/positional_encoding.py    |  85 ++++++
 .../positional_encodings/rotary_embedding.py       |  39 +++
 .../networks/transformer/rotary_embedding.py       |  39 ---
 .../networks/transformer/transformer.py            | 321 ++++-----------------
 text_recognizer/networks/transformer/vit.py        |  46 +++
 12 files changed, 277 insertions(+), 402 deletions(-)
 delete mode 100644 text_recognizer/networks/transformer/positional_encoding.py
 create mode 100644 text_recognizer/networks/transformer/positional_encodings/__init__.py
 create mode 100644 text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
 create mode 100644 text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
 create mode 100644 text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
 delete mode 100644 text_recognizer/networks/transformer/rotary_embedding.py

(limited to 'text_recognizer/networks/transformer')

diff --git a/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 4ff48f7..a3f3011 100644
--- a/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
@@ -1,8 +1 @@
 """Transformer modules."""
-from .positional_encoding import (
-    PositionalEncoding,
-    PositionalEncoding2D,
-    target_padding_mask,
-)
-
-# from .transformer import Decoder, DecoderLayer, Encoder, EncoderLayer, Transformer
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 623d680..eabeadf 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -9,7 +9,9 @@ from torch import nn
 from torch import Tensor
 import torch.nn.functional as F
 
-from text_recognizer.networks.transformer.rotary_embedding import apply_rotary_pos_emb
+from text_recognizer.networks.transformer.positional_encodings.rotary_embedding import (
+    apply_rotary_pos_emb,
+)
 
 
 class Attention(nn.Module):
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index a2fdb1a..4063425 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,6 +1,6 @@
 """Generates the attention layer architecture."""
 from functools import partial
-from typing import Dict, Optional, Type
+from typing import Any, Dict, Optional, Type
 
 from click.types import Tuple
 
@@ -36,12 +36,11 @@ class AttentionLayers(nn.Module):
         norm_fn = partial(norm_fn, dim=dim)
         ff_fn = partial(ff_fn, dim=dim, **ff_kwargs)
         self.layer_types = self._get_layer_types(cross_attend) * depth
-        self.layers = self._build_network(
-            causal, attn_fn, norm_fn, ff_fn, residual_fn
-        )
+        self.layers = self._build_network(causal, attn_fn, norm_fn, ff_fn, residual_fn)
         rotary_emb_dim = max(rotary_emb_dim, 32) if rotary_emb_dim is not None else None
         self.rotary_emb = RotaryEmbedding(rotary_emb_dim) if rotary_emb else None
         self.pre_norm = pre_norm
+        self.has_pos_emb = self.rotary_emb is not None
 
     @staticmethod
     def _get_layer_types(cross_attend: bool) -> Tuple:
@@ -70,7 +69,7 @@ class AttentionLayers(nn.Module):
 
             residual_fn = residual_fn()
 
-            layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
+            layers.append(nn.ModuleList([norm_fn(), layer, residual_fn]))
         return layers
 
     def forward(
@@ -82,10 +81,12 @@ class AttentionLayers(nn.Module):
     ) -> Tensor:
         rotary_pos_emb = self.rotary_emb(x) if self.rotary_emb is not None else None
 
-        for i, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
+        for i, (layer_type, (norm, block, residual_fn)) in enumerate(
+            zip(self.layer_types, self.layers)
+        ):
             is_last = i == len(self.layers) - 1
-            
-           residual = x 
+
+            residual = x
 
             if self.pre_norm:
                 x = norm(x)
@@ -103,3 +104,15 @@ class AttentionLayers(nn.Module):
                 x = norm(x)
 
         return x
+
+
+class Encoder(AttentionLayers):
+    """Encoder stack of attention layers without causal masking."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        assert "causal" not in kwargs, "Cannot set causality on encoder"
+        super().__init__(causal=False, **kwargs)
+
+
+class Decoder(AttentionLayers):
+    """Autoregressive decoder stack with causal masking."""
+
+    def __init__(self, **kwargs: Any) -> None:
+        assert "causal" not in kwargs, "Cannot set causality on decoder"
+        super().__init__(causal=True, **kwargs)
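
The Encoder and Decoder subclasses above only pin the causal flag and forward everything else to AttentionLayers. A minimal usage sketch follows, assuming AttentionLayers accepts dim, depth, num_heads and cross_attend keyword arguments (its full constructor is not shown in this diff, so treat the exact keywords as assumptions):

    from text_recognizer.networks.transformer.layers import Decoder, Encoder

    encoder = Encoder(dim=256, depth=4, num_heads=8)
    decoder = Decoder(dim=256, depth=2, num_heads=8, cross_attend=True)

    # Setting causal explicitly is rejected by design:
    # Encoder(dim=256, depth=4, num_heads=8, causal=True)  # AssertionError
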
diff --git a/text_recognizer/networks/transformer/nystromer/attention.py b/text_recognizer/networks/transformer/nystromer/attention.py
index c2871fb..5ab19cf 100644
--- a/text_recognizer/networks/transformer/nystromer/attention.py
+++ b/text_recognizer/networks/transformer/nystromer/attention.py
@@ -157,14 +157,14 @@ class NystromAttention(nn.Module):
         self, x: Tensor, mask: Optional[Tensor] = None, return_attn: bool = False
     ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
         """Compute the Nystrom attention."""
-        _, n, _, h, m = x.shape, self.num_heads
+        _, n, _, h, m = *x.shape, self.num_heads, self.num_landmarks
         if n % m != 0:
             x, mask = self._pad_sequence(x, mask, n, m)
 
         q, k, v = self.qkv_fn(x).chunk(3, dim=-1)
         q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
 
-        q *= self.scale
+        q = q * self.scale
 
         out, attn = self._nystrom_attention(q, k, v, mask, n, m, return_attn)
 
diff --git a/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
deleted file mode 100644
index c50afc3..0000000
--- a/text_recognizer/networks/transformer/positional_encoding.py
+++ /dev/null
@@ -1,85 +0,0 @@
-"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
-from einops import repeat
-import numpy as np
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class PositionalEncoding(nn.Module):
-    """Encodes a sense of distance or time for transformer networks."""
-
-    def __init__(
-        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
-    ) -> None:
-        super().__init__()
-        self.dropout = nn.Dropout(p=dropout_rate)
-        pe = self.make_pe(hidden_dim, max_len)
-        self.register_buffer("pe", pe)
-
-    @staticmethod
-    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
-        """Returns positional encoding."""
-        pe = torch.zeros(max_len, hidden_dim)
-        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
-        div_term = torch.exp(
-            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
-        )
-
-        pe[:, 0::2] = torch.sin(position * div_term)
-        pe[:, 1::2] = torch.cos(position * div_term)
-        pe = pe.unsqueeze(1)
-        return pe
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Encodes the tensor with a postional embedding."""
-        # [T, B, D]
-        if x.shape[2] != self.pe.shape[2]:
-            raise ValueError(f"x shape does not match pe in the 3rd dim.")
-        x = x + self.pe[: x.shape[0]]
-        return self.dropout(x)
-
-
-class PositionalEncoding2D(nn.Module):
-    """Positional encodings for feature maps."""
-
-    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
-        super().__init__()
-        if hidden_dim % 2 != 0:
-            raise ValueError(f"Embedding depth {hidden_dim} is not even!")
-        self.hidden_dim = hidden_dim
-        pe = self.make_pe(hidden_dim, max_h, max_w)
-        self.register_buffer("pe", pe)
-
-    @staticmethod
-    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
-        """Returns 2d postional encoding."""
-        pe_h = PositionalEncoding.make_pe(
-            hidden_dim // 2, max_len=max_h
-        )  # [H, 1, D // 2]
-        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
-
-        pe_w = PositionalEncoding.make_pe(
-            hidden_dim // 2, max_len=max_w
-        )  # [W, 1, D // 2]
-        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
-
-        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
-        return pe
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Adds 2D postional encoding to input tensor."""
-        # Assumes x hase shape [B, D, H, W]
-        if x.shape[1] != self.pe.shape[0]:
-            raise ValueError("Hidden dimensions does not match.")
-        x += self.pe[:, : x.shape[2], : x.shape[3]]
-        return x
-
-
-def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
-    """Returns causal target mask."""
-    trg_pad_mask = (trg != pad_index)[:, None, None]
-    trg_len = trg.shape[1]
-    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
-    trg_mask = trg_pad_mask & trg_sub_mask
-    return trg_mask
diff --git a/text_recognizer/networks/transformer/positional_encodings/__init__.py b/text_recognizer/networks/transformer/positional_encodings/__init__.py
new file mode 100644
index 0000000..91278ee
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encodings/__init__.py
@@ -0,0 +1,4 @@
+"""Positional encoding for transformers."""
+from .absolute_embedding import AbsolutePositionalEmbedding
+from .positional_encoding import PositionalEncoding, PositionalEncoding2D
+from .rotary_embedding import apply_rotary_pos_emb, RotaryEmbedding
diff --git a/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
new file mode 100644
index 0000000..9466f6e
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py
@@ -0,0 +1,16 @@
+"""Absolute positional embedding."""
+import torch
+from torch import nn, Tensor
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+    def __init__(self, dim: int, max_seq_len: int) -> None:
+        super().__init__()
+        self.emb = nn.Embedding(max_seq_len, dim)
+        self._weight_init()
+
+    def _weight_init(self) -> None:
+        nn.init.normal_(self.emb.weight, std=0.02)
+
+    def forward(self, x: Tensor) -> Tensor:
+        n = torch.arange(x.shape[1], device=x.device)
+        return self.emb(n)[None, :, :]
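
The embedding table is indexed with positions 0..N-1 and returned with a leading singleton batch dimension, so it broadcasts when added to token embeddings. A short usage sketch:

    import torch
    from text_recognizer.networks.transformer.positional_encodings import (
        AbsolutePositionalEmbedding,
    )

    tokens = torch.randn(8, 42, 256)  # [B, N, D] token embeddings
    pos_emb = AbsolutePositionalEmbedding(dim=256, max_seq_len=128)
    out = tokens + pos_emb(tokens)    # [1, 42, 256] broadcasts over the batch
    print(out.shape)                  # torch.Size([8, 42, 256])
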
diff --git a/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
new file mode 100644
index 0000000..c50afc3
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encodings/positional_encoding.py
@@ -0,0 +1,85 @@
+"""A positional encoding for the image features, as the transformer has no notation of the order of the sequence."""
+from einops import repeat
+import numpy as np
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class PositionalEncoding(nn.Module):
+    """Encodes a sense of distance or time for transformer networks."""
+
+    def __init__(
+        self, hidden_dim: int, dropout_rate: float, max_len: int = 1000
+    ) -> None:
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout_rate)
+        pe = self.make_pe(hidden_dim, max_len)
+        self.register_buffer("pe", pe)
+
+    @staticmethod
+    def make_pe(hidden_dim: int, max_len: int) -> Tensor:
+        """Returns positional encoding."""
+        pe = torch.zeros(max_len, hidden_dim)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(
+            torch.arange(0, hidden_dim, 2).float() * (-np.log(10000.0) / hidden_dim)
+        )
+
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(1)
+        return pe
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Encodes the tensor with a postional embedding."""
+        # [T, B, D]
+        if x.shape[2] != self.pe.shape[2]:
+            raise ValueError(f"x shape does not match pe in the 3rd dim.")
+        x = x + self.pe[: x.shape[0]]
+        return self.dropout(x)
+
+
+class PositionalEncoding2D(nn.Module):
+    """Positional encodings for feature maps."""
+
+    def __init__(self, hidden_dim: int, max_h: int = 2048, max_w: int = 2048) -> None:
+        super().__init__()
+        if hidden_dim % 2 != 0:
+            raise ValueError(f"Embedding depth {hidden_dim} is not even!")
+        self.hidden_dim = hidden_dim
+        pe = self.make_pe(hidden_dim, max_h, max_w)
+        self.register_buffer("pe", pe)
+
+    @staticmethod
+    def make_pe(hidden_dim: int, max_h: int, max_w: int) -> Tensor:
+        """Returns 2d postional encoding."""
+        pe_h = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_h
+        )  # [H, 1, D // 2]
+        pe_h = repeat(pe_h, "h w d -> d h (w tile)", tile=max_w)
+
+        pe_w = PositionalEncoding.make_pe(
+            hidden_dim // 2, max_len=max_w
+        )  # [W, 1, D // 2]
+        pe_w = repeat(pe_w, "w h d -> d (h tile) w", tile=max_h)
+
+        pe = torch.cat([pe_h, pe_w], dim=0)  # [D, H, W]
+        return pe
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Adds 2D postional encoding to input tensor."""
+        # Assumes x hase shape [B, D, H, W]
+        if x.shape[1] != self.pe.shape[0]:
+            raise ValueError("Hidden dimensions does not match.")
+        x += self.pe[:, : x.shape[2], : x.shape[3]]
+        return x
+
+
+def target_padding_mask(trg: Tensor, pad_index: int) -> Tensor:
+    """Returns causal target mask."""
+    trg_pad_mask = (trg != pad_index)[:, None, None]
+    trg_len = trg.shape[1]
+    trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()
+    trg_mask = trg_pad_mask & trg_sub_mask
+    return trg_mask
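
For reference, the 1D module expects [T, B, D] input while the 2D module expects [B, D, H, W] feature maps. A short usage sketch:

    import torch
    from text_recognizer.networks.transformer.positional_encodings import (
        PositionalEncoding,
        PositionalEncoding2D,
    )

    seq = torch.zeros(100, 4, 256)             # [T, B, D]
    feature_map = torch.zeros(4, 256, 18, 20)  # [B, D, H, W]

    pe_1d = PositionalEncoding(hidden_dim=256, dropout_rate=0.1)
    pe_2d = PositionalEncoding2D(hidden_dim=256, max_h=32, max_w=32)

    print(pe_1d(seq).shape)          # torch.Size([100, 4, 256])
    print(pe_2d(feature_map).shape)  # torch.Size([4, 256, 18, 20])
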
diff --git a/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
new file mode 100644
index 0000000..5e80572
--- /dev/null
+++ b/text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py
@@ -0,0 +1,39 @@
+"""Roatary embedding.
+
+Stolen from lucidrains:
+    https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
+
+Explanation of rotary embeddings:
+    https://blog.eleuther.ai/rotary-embeddings/
+
+"""
+from typing import Tuple
+
+from einops import rearrange
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class RotaryEmbedding(nn.Module):
+    def __init__(self, dim: int):
+        super().__init__()
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
+
+    def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
+        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+        freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+        emb = torch.cat((freqs, freqs), dim=-1)
+        return emb[None, :, :]
+
+
+def rotate_half(x: Tensor) -> Tensor:
+    x = rearrange(x, "... (j d) -> ... j d", j=2)
+    x1, x2 = x.unbind(dim=-2)
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
+    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
+    return q, k
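
The rotary module only produces the frequency table; the rotation itself is applied to the queries and keys inside the attention block (see the attention.py import change above). A sketch of that flow, with the [B, heads, N, head_dim] shapes assumed:

    import torch
    from text_recognizer.networks.transformer.positional_encodings import (
        RotaryEmbedding,
        apply_rotary_pos_emb,
    )

    q = torch.randn(2, 8, 16, 64)  # [B, heads, N, head_dim]
    k = torch.randn(2, 8, 16, 64)

    rotary = RotaryEmbedding(dim=64)
    freqs = rotary(q, seq_dim=2)   # [1, N, head_dim]
    q, k = apply_rotary_pos_emb(q, k, freqs)
    scores = torch.einsum("b h i d, b h j d -> b h i j", q, k) * 64 ** -0.5
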
diff --git a/text_recognizer/networks/transformer/rotary_embedding.py b/text_recognizer/networks/transformer/rotary_embedding.py
deleted file mode 100644
index 5e80572..0000000
--- a/text_recognizer/networks/transformer/rotary_embedding.py
+++ /dev/null
@@ -1,39 +0,0 @@
-"""Roatary embedding.
-
-Stolen from lucidrains:
-    https://github.com/lucidrains/x-transformers/blob/main/x_transformers/x_transformers.py
-
-Explanation of roatary:
-    https://blog.eleuther.ai/rotary-embeddings/
-
-"""
-from typing import Tuple
-
-from einops import rearrange
-import torch
-from torch import nn
-from torch import Tensor
-
-
-class RotaryEmbedding(nn.Module):
-    def __init__(self, dim: int):
-        super().__init__()
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer("inv_freq", inv_freq)
-
-    def forward(self, x: Tensor, seq_dim: int = 1) -> Tensor:
-        t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
-        freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
-        emb = torch.cat((freqs, freqs), dim=-1)
-        return emb[None, :, :]
-
-
-def rotate_half(x: Tensor) -> Tensor:
-    x = rearrange(x, "... (j d) -> ... j d", j=2)
-    x1, x2 = x.unbind(dim=-2)
-    return torch.cat((-x2, x1), dim=-1)
-
-
-def apply_rotary_pos_emb(q: Tensor, k: Tensor, freqs: Tensor) -> Tuple[Tensor, Tensor]:
-    q, k = map(lambda t: (t * freqs.cos()) + (rotate_half(t) * freqs.sin()), (q, k))
-    return q, k
diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index d49c85a..36f86ac 100644
--- a/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -1,260 +1,61 @@
-# """Transfomer module."""
-# import copy
-# from typing import Dict, Optional, Type, Union
-#
-# import numpy as np
-# import torch
-# from torch import nn
-# from torch import Tensor
-# import torch.nn.functional as F
-#
-# from text_recognizer.networks.transformer.attention import MultiHeadAttention
-# from text_recognizer.networks.util import activation_function
-#
-#
-# class GEGLU(nn.Module):
-#     """GLU activation for improving feedforward activations."""
-#
-#     def __init__(self, dim_in: int, dim_out: int) -> None:
-#         super().__init__()
-#         self.proj = nn.Linear(dim_in, dim_out * 2)
-#
-#     def forward(self, x: Tensor) -> Tensor:
-#         """Forward propagation."""
-#         x, gate = self.proj(x).chunk(2, dim=-1)
-#         return x * F.gelu(gate)
-#
-#
-# def _get_clones(module: Type[nn.Module], num_layers: int) -> nn.ModuleList:
-#     return nn.ModuleList([copy.deepcopy(module) for _ in range(num_layers)])
-#
-#
-# class _IntraLayerConnection(nn.Module):
-#     """Preforms the residual connection inside the transfomer blocks and applies layernorm."""
-#
-#     def __init__(self, dropout_rate: float, hidden_dim: int) -> None:
-#         super().__init__()
-#         self.norm = nn.LayerNorm(normalized_shape=hidden_dim)
-#         self.dropout = nn.Dropout(p=dropout_rate)
-#
-#     def forward(self, src: Tensor, residual: Tensor) -> Tensor:
-#         return self.norm(self.dropout(src) + residual)
-#
-#
-# class FeedForward(nn.Module):
-#     def __init__(
-#         self,
-#         hidden_dim: int,
-#         expansion_dim: int,
-#         dropout_rate: float,
-#         activation: str = "relu",
-#     ) -> None:
-#         super().__init__()
-#
-#         in_projection = (
-#             nn.Sequential(
-#                 nn.Linear(hidden_dim, expansion_dim), activation_function(activation)
-#             )
-#             if activation != "glu"
-#             else GEGLU(hidden_dim, expansion_dim)
-#         )
-#
-#         self.layer = nn.Sequential(
-#             in_projection,
-#             nn.Dropout(p=dropout_rate),
-#             nn.Linear(in_features=expansion_dim, out_features=hidden_dim),
-#         )
-#
-#     def forward(self, x: Tensor) -> Tensor:
-#         return self.layer(x)
-#
-#
-# class EncoderLayer(nn.Module):
-#     """Transfomer encoding layer."""
-#
-#     def __init__(
-#         self,
-#         hidden_dim: int,
-#         num_heads: int,
-#         expansion_dim: int,
-#         dropout_rate: float,
-#         activation: str = "relu",
-#     ) -> None:
-#         super().__init__()
-#         self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-#         self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
-#         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#
-#     def forward(self, src: Tensor, mask: Optional[Tensor] = None) -> Tensor:
-#         """Forward pass through the encoder."""
-#         # First block.
-#         # Multi head attention.
-#         out, _ = self.self_attention(src, src, src, mask)
-#
-#         # Add & norm.
-#         out = self.block1(out, src)
-#
-#         # Second block.
-#         # Apply 1D-convolution.
-#         mlp_out = self.mlp(out)
-#
-#         # Add & norm.
-#         out = self.block2(mlp_out, out)
-#
-#         return out
-#
-#
-# class Encoder(nn.Module):
-#     """Transfomer encoder module."""
-#
-#     def __init__(
-#         self,
-#         num_layers: int,
-#         encoder_layer: Type[nn.Module],
-#         norm: Optional[Type[nn.Module]] = None,
-#     ) -> None:
-#         super().__init__()
-#         self.layers = _get_clones(encoder_layer, num_layers)
-#         self.norm = norm
-#
-#     def forward(self, src: Tensor, src_mask: Optional[Tensor] = None) -> Tensor:
-#         """Forward pass through all encoder layers."""
-#         for layer in self.layers:
-#             src = layer(src, src_mask)
-#
-#         if self.norm is not None:
-#             src = self.norm(src)
-#
-#         return src
-#
-#
-# class DecoderLayer(nn.Module):
-#     """Transfomer decoder layer."""
-#
-#     def __init__(
-#         self,
-#         hidden_dim: int,
-#         num_heads: int,
-#         expansion_dim: int,
-#         dropout_rate: float = 0.0,
-#         activation: str = "relu",
-#     ) -> None:
-#         super().__init__()
-#         self.hidden_dim = hidden_dim
-#         self.self_attention = MultiHeadAttention(hidden_dim, num_heads, dropout_rate)
-#         self.multihead_attention = MultiHeadAttention(
-#             hidden_dim, num_heads, dropout_rate
-#         )
-#         self.mlp = FeedForward(hidden_dim, expansion_dim, dropout_rate, activation)
-#         self.block1 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#         self.block2 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#         self.block3 = _IntraLayerConnection(dropout_rate, hidden_dim)
-#
-#     def forward(
-#         self,
-#         trg: Tensor,
-#         memory: Tensor,
-#         trg_mask: Optional[Tensor] = None,
-#         memory_mask: Optional[Tensor] = None,
-#     ) -> Tensor:
-#         """Forward pass of the layer."""
-#         out, _ = self.self_attention(trg, trg, trg, trg_mask)
-#         trg = self.block1(out, trg)
-#
-#         out, _ = self.multihead_attention(trg, memory, memory, memory_mask)
-#         trg = self.block2(out, trg)
-#
-#         out = self.mlp(trg)
-#         out = self.block3(out, trg)
-#
-#         return out
-#
-#
-# class Decoder(nn.Module):
-#     """Transfomer decoder module."""
-#
-#     def __init__(
-#         self,
-#         decoder_layer: Type[nn.Module],
-#         num_layers: int,
-#         norm: Optional[Type[nn.Module]] = None,
-#     ) -> None:
-#         super().__init__()
-#         self.layers = _get_clones(decoder_layer, num_layers)
-#         self.num_layers = num_layers
-#         self.norm = norm
-#
-#     def forward(
-#         self,
-#         trg: Tensor,
-#         memory: Tensor,
-#         trg_mask: Optional[Tensor] = None,
-#         memory_mask: Optional[Tensor] = None,
-#     ) -> Tensor:
-#         """Forward pass through the decoder."""
-#         for layer in self.layers:
-#             trg = layer(trg, memory, trg_mask, memory_mask)
-#
-#         if self.norm is not None:
-#             trg = self.norm(trg)
-#
-#         return trg
-#
-#
-# class Transformer(nn.Module):
-#     """Transformer network."""
-#
-#     def __init__(
-#         self,
-#         num_encoder_layers: int,
-#         num_decoder_layers: int,
-#         hidden_dim: int,
-#         num_heads: int,
-#         expansion_dim: int,
-#         dropout_rate: float,
-#         activation: str = "relu",
-#     ) -> None:
-#         super().__init__()
-#
-#         # Configure encoder.
-#         encoder_norm = nn.LayerNorm(hidden_dim)
-#         encoder_layer = EncoderLayer(
-#             hidden_dim, num_heads, expansion_dim, dropout_rate, activation
-#         )
-#         self.encoder = Encoder(num_encoder_layers, encoder_layer, encoder_norm)
-#
-#         # Configure decoder.
-#         decoder_norm = nn.LayerNorm(hidden_dim)
-#         decoder_layer = DecoderLayer(
-#             hidden_dim, num_heads, expansion_dim, dropout_rate, activation
-#         )
-#         self.decoder = Decoder(decoder_layer, num_decoder_layers, decoder_norm)
-#
-#         self._reset_parameters()
-#
-#     def _reset_parameters(self) -> None:
-#         for p in self.parameters():
-#             if p.dim() > 1:
-#                 nn.init.xavier_uniform_(p)
-#
-#     def forward(
-#         self,
-#         src: Tensor,
-#         trg: Tensor,
-#         src_mask: Optional[Tensor] = None,
-#         trg_mask: Optional[Tensor] = None,
-#         memory_mask: Optional[Tensor] = None,
-#     ) -> Tensor:
-#         """Forward pass through the transformer."""
-#         if src.shape[0] != trg.shape[0]:
-#             print(trg.shape)
-#             raise RuntimeError("The batch size of the src and trg must be the same.")
-#         if src.shape[2] != trg.shape[2]:
-#             raise RuntimeError(
-#                 "The number of features for the src and trg must be the same."
-#             )
-#
-#         memory = self.encoder(src, src_mask)
-#         output = self.decoder(trg, memory, trg_mask, memory_mask)
-#         return output
+"""Transformer wrapper."""
+from typing import Any, Optional, Type
+
+from torch import nn, Tensor
+
+from .layers import AttentionLayers
+from text_recognizer.networks.transformer.positional_encodings import (
+    AbsolutePositionalEmbedding,
+)
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        num_tokens: int,
+        max_seq_len: int,
+        attn_layers: AttentionLayers,
+        emb_dim: Optional[int] = None,
+        emb_dropout: float = 0.0,
+        use_pos_emb: bool = True,
+    ) -> None:
+        super().__init__()
+        dim = attn_layers.dim
+        emb_dim = emb_dim if emb_dim is not None else dim
+        self.max_seq_len = max_seq_len
+
+        self.token_emb = nn.Embedding(num_tokens, emb_dim)
+        self.emb_dropout = nn.Dropout(emb_dropout)
+        self.pos_emb = (
+            AbsolutePositionalEmbedding(emb_dim, max_seq_len)
+            if (use_pos_emb and not attn_layers.has_pos_emb)
+            else None
+        )
+
+        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        self.attn_layers = attn_layers
+        self.norm = nn.LayerNorm(dim)
+
+        self._init_weights()
+
+        self.logits = nn.Linear(dim, num_tokens)
+
+    def _init_weights(self) -> None:
+        nn.init.normal_(self.token_emb.weight, std=0.02)
+
+    def forward(
+        self,
+        x: Tensor,
+        mask: Optional[Tensor] = None,
+        return_embeddings: bool = False,
+        **kwargs: Any
+    ) -> Tensor:
+        n = x.shape[1]
+        assert n <= self.max_seq_len, "Sequence length exceeds max_seq_len."
+        x = self.token_emb(x)  # [B, N] token indices -> [B, N, D] embeddings.
+        if self.pos_emb is not None:
+            x += self.pos_emb(x)
+        x = self.emb_dropout(x)
+
+        x = self.project_emb(x)
+        x = self.attn_layers(x, mask=mask, **kwargs)
+        out = self.logits(x) if not return_embeddings else x
+        return out
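
A hypothetical wiring of the wrapper with the new Decoder; the Decoder keyword arguments below are assumptions, and attn_layers is expected to be an instantiated AttentionLayers stack exposing .dim and .has_pos_emb:

    import torch
    from text_recognizer.networks.transformer.layers import Decoder
    from text_recognizer.networks.transformer.transformer import Transformer

    decoder = Decoder(dim=256, depth=2, num_heads=8)
    model = Transformer(
        num_tokens=1000,
        max_seq_len=128,
        attn_layers=decoder,
        emb_dropout=0.1,
    )
    tokens = torch.randint(0, 1000, (4, 32))  # [B, N] token indices
    logits = model(tokens, mask=None)         # [B, N, num_tokens]
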
diff --git a/text_recognizer/networks/transformer/vit.py b/text_recognizer/networks/transformer/vit.py
index e69de29..ab331f8 100644
--- a/text_recognizer/networks/transformer/vit.py
+++ b/text_recognizer/networks/transformer/vit.py
@@ -0,0 +1,46 @@
+"""Vision Transformer."""
+from typing import Tuple, Type
+
+from einops.layers.torch import Rearrange
+import torch
+from torch import nn, Tensor
+
+
+class ViT(nn.Module):
+    def __init__(
+        self,
+        image_size: Tuple[int, int],
+        patch_size: Tuple[int, int],
+        dim: int,
+        transformer: nn.Module,
+        channels: int = 1,
+    ) -> None:
+        super().__init__()
+        img_height, img_width = image_size
+        patch_height, patch_width = patch_size
+        assert img_height % patch_height == 0
+        assert img_width % patch_width == 0
+
+        num_patches = (img_height // patch_height) * (img_width // patch_width)
+        patch_dim = channels * patch_height * patch_width
+
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange(
+                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
+                p1=patch_height,
+                p2=patch_width,
+                c=channels,
+            ),
+            nn.Linear(patch_dim, dim),
+        )
+
+        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
+        self.transformer = transformer
+        self.norm = nn.LayerNorm(dim)
+
+    def forward(self, img: Tensor) -> Tensor:
+        x = self.to_patch_embedding(img)
+        _, n, _ = x.shape
+        x += self.pos_embedding[:, :n]
+        x = self.transformer(x)
+        return x
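
A hypothetical instantiation of the ViT wrapper around the new Encoder; the Encoder keyword arguments are assumptions, and any module mapping [B, N, dim] to [B, N, dim] would work as the transformer argument:

    import torch
    from text_recognizer.networks.transformer.layers import Encoder
    from text_recognizer.networks.transformer.vit import ViT

    vit = ViT(
        image_size=(576, 640),
        patch_size=(32, 32),
        dim=256,
        transformer=Encoder(dim=256, depth=4, num_heads=8),
        channels=1,
    )
    img = torch.randn(1, 1, 576, 640)
    features = vit(img)  # [1, (576 // 32) * (640 // 32), 256] = [1, 360, 256]
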
-- 
cgit v1.2.3-70-g09d2