From 89be047a46c8e88511d301f63d7f6795fe04ab81 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sat, 27 Nov 2021 12:45:19 +0100
Subject: Revert "Remove default transformer"

This reverts commit b3d3b7ddc0796e98d78561bc5ca22728dc0372b0.
---
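Note: below is a minimal usage sketch of the restored wrapper; it is not part
of the commit itself. The real AttentionLayers implementation lives in
text_recognizer/networks/transformer/layers.py and its constructor is not
shown in this patch, so the stand-in below only mimics the attributes that
Transformer actually relies on: .dim, .has_pos_emb, and a forward that
accepts a mask.

    import torch
    from torch import nn, Tensor

    from text_recognizer.networks.transformer.transformer import Transformer

    class StubAttentionLayers(nn.Module):
        """Stand-in exposing the interface Transformer expects."""

        def __init__(self, dim: int) -> None:
            super().__init__()
            self.dim = dim            # read by Transformer.__init__
            self.has_pos_emb = False  # Transformer then adds positional embeddings
            self.block = nn.Linear(dim, dim)

        def forward(self, x: Tensor, mask=None, **kwargs) -> Tensor:
            return self.block(x)

    model = Transformer(
        num_tokens=1000,
        max_seq_len=128,
        attn_layers=StubAttentionLayers(dim=256),
    )
    tokens = torch.randint(0, 1000, (1, 128))  # (batch, sequence)
    logits = model(tokens)                     # -> shape (1, 128, 1000)
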
 .../networks/transformer/transformer.py            | 67 ++++++++++++++++++++++
 1 file changed, 67 insertions(+)
 create mode 100644 text_recognizer/networks/transformer/transformer.py

diff --git a/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
new file mode 100644
index 0000000..31088b4
--- /dev/null
+++ b/text_recognizer/networks/transformer/transformer.py
@@ -0,0 +1,67 @@
+"""Transformer wrapper."""
+from typing import Any, Optional
+
+from torch import nn, Tensor
+
+from .layers import AttentionLayers
+from text_recognizer.networks.transformer.positional_encodings import (
+    AbsolutePositionalEmbedding,
+)
+
+
+class Transformer(nn.Module):
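+    """Wraps token and positional embeddings around attention layers."""
+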
+    def __init__(
+        self,
+        num_tokens: int,
+        max_seq_len: int,
+        attn_layers: AttentionLayers,
+        emb_dim: Optional[int] = None,
+        emb_dropout: float = 0.0,
+        use_pos_emb: bool = True,
+    ) -> None:
+        super().__init__()
+        dim = attn_layers.dim
+        self.attn_layers = attn_layers
+        emb_dim = emb_dim if emb_dim is not None else dim
+        self.max_seq_len = max_seq_len
+
+        self.token_emb = nn.Embedding(num_tokens, emb_dim)
+        self.emb_dropout = nn.Dropout(emb_dropout)
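+        # Use absolute positional embeddings unless disabled or the
+        # attention layers already provide positional information.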
+        self.pos_emb = (
+            AbsolutePositionalEmbedding(emb_dim, max_seq_len)
+            if (use_pos_emb and not self.attn_layers.has_pos_emb)
+            else None
+        )
+
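+        # Project embeddings to the attention dimension if the two differ.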
+        self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
+        self.norm = nn.LayerNorm(dim)
+        self.logits = nn.Linear(dim, num_tokens)
+
+        self._init_weights()
+
+    def _init_weights(self) -> None:
+        nn.init.normal_(self.token_emb.weight, std=0.02)
+
+    def forward(
+        self,
+        x: Tensor,
+        mask: Optional[Tensor] = None,
+        return_embeddings: bool = False,
+        **kwargs: Any
+    ) -> Tensor:
+        x = self.token_emb(x)
+        if self.pos_emb is not None:
+            x += self.pos_emb(x)
+        x = self.emb_dropout(x)
+
+        x = self.project_emb(x)
+        x = self.attn_layers(x, mask=mask, **kwargs)
+        # Apply the final layer norm before projecting to logits.
+        x = self.norm(x)
+        out = self.logits(x) if not return_embeddings else x
+        return out
-- 
cgit v1.2.3-70-g09d2