From d717f9e3e4dd17351f08c5822cb90d055c4513cc Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 8 Jul 2021 22:24:28 +0200
Subject: Working on cnn transformer module

---
 text_recognizer/networks/cnn_tranformer.py | 123 ++++++++++++++++++++++++++++-
 1 file changed, 121 insertions(+), 2 deletions(-)

diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
index da69311..38de0ba 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -10,5 +10,124 @@ class CnnTransformer(nn.Module):
     def __attrs_pre_init__(self) -> None:
         super().__init__()
 
-    backbone: Type[nn.Module] = attr.ib()
-    head = Type[nn.Module] = attr.ib()
+    # Parameters.
+    input_dims: Tuple[int, int, int] = attr.ib()
+    hidden_dim: int = attr.ib()
+    dropout_rate: float = attr.ib()
+    max_output_len: int = attr.ib()
+    num_classes: int = attr.ib()
+
+    # Modules.
+    encoder: Type[nn.Module] = attr.ib()
+    decoder: Type[nn.Module] = attr.ib()
+    embedding: nn.Embedding = attr.ib(init=False, default=None)
+    latent_encoder: nn.Sequential = attr.ib(init=False, default=None)
+    token_embedding: nn.Embedding = attr.ib(init=False, default=None)
+    token_pos_encoder: PositionalEncoding = attr.ib(init=False, default=None)
+    head: nn.Linear = attr.ib(init=False, default=None)
+
+    def __attrs_post_init__(self) -> None:
+        # Latent projector for downsampling the number of filters and
+        # adding 2D positional encoding.
+        self.latent_encoder = nn.Sequential(
+            nn.Conv2d(
+                in_channels=self.encoder.out_channels,
+                out_channels=self.hidden_dim,
+                kernel_size=1,
+            ),
+            PositionalEncoding2D(
+                hidden_dim=self.hidden_dim,
+                max_h=self.input_dims[1],
+                max_w=self.input_dims[2],
+            ),
+            nn.Flatten(start_dim=2),
+        )
+
+        # Token embedding.
+        self.token_embedding = nn.Embedding(
+            num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
+        )
+
+        # Positional encoding for decoder tokens.
+        self.token_pos_encoder = PositionalEncoding(
+            hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate
+        )
+        # Head
+        self.head = nn.Linear(
+            in_features=self.hidden_dim, out_features=self.num_classes
+        )
+
+        # Initialize weights for the token embedding and head.
+        self.init_weights()
+
+    def init_weights(self) -> None:
+        """Initalize weights for decoder network and head."""
+        bound = 0.1
+        self.token_embedding.weight.data.uniform_(-bound, bound)
+        self.head.bias.data.zero_()
+        self.head.weight.data.uniform_(-bound, bound)
+        # TODO: Initialize encoder?
+
+    def encode(self, x: Tensor) -> Tensor:
+        """Encodes an image into a latent feature vector.
+
+        Args:
+            x (Tensor): Image tensor.
+
+        Shape:
+            - x: :math:`(B, C, H, W)`
+            - z: :math:`(Sx, B, E)`
+
+            where Sx is the length of the flattened feature map projected from
+            the encoder and E is the latent dimension of each projected
+            feature.
+
+        Returns:
+            Tensor: A latent embedding of the image.
+        """
+        z = self.encoder(x)
+        z = self.latent_encoder(z)
+
+        # Permute tensor from [B, E, Ho * Wo] to [Sx, B, E]
+        z = z.permute(2, 0, 1)
+        return z
+
+    def decode(self, z: Tensor, trg: Tensor) -> Tensor:
+        """Decodes latent images embedding into word pieces.
+
+        Args:
+            z (Tensor): Latent image embedding.
+            trg (Tensor): Word embeddings.
+
+        Shapes:
+            - z: :math:`(Sx, B, E)`
+            - out: :math:`(B, Sy, T)`
+
+            where Sy is the length of the output sequence and T is the number of tokens.
+
+        Returns:
+            Tensor: Sequence of word piece logits.
+        """
+        pass
+
+    def forward(self, x: Tensor, trg: Tensor) -> Tensor:
+        """Encodes images into word piece logtis.
+
+        Args:
+            x (Tensor): Input image(s).
+            trg (Tensor): Target word embeddings.
+
+        Shapes:
+            - x: :math:`(B, C, H, W)`
+            - trg: :math:`(B, Sy, T)`
+
+            where B is the batch size, C is the number of input channels, H is
+            the image height, and W is the image width.
+        """
+        z = self.encode(x)
+        y = self.decode(z, trg)
+        return y
+
+    def predict(self, x: Tensor) -> Tensor:
+        """Predicts text in image."""
+        pass
-- 
cgit v1.2.3-70-g09d2
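
The decode and predict methods are left as stubs in this commit. For reference, a minimal sketch of how they might be completed, assuming that self.decoder follows the standard torch.nn.TransformerDecoder interface, that self.token_pos_encoder accepts a batch-first (B, Sy, E) tensor, and that trg holds integer token indices with hypothetical START_INDEX, END_INDEX, and PAD_INDEX constants. None of these assumptions are confirmed by the patch; this is not the author's implementation.

    import torch
    from torch import Tensor

    # Hypothetical special-token indices; the real mapping lives in the
    # project's tokenizer and is not part of this patch.
    START_INDEX, END_INDEX, PAD_INDEX = 1, 2, 3

    def decode(self, z: Tensor, trg: Tensor) -> Tensor:
        """Sketch: decode the latent image embedding into word piece logits."""
        # (B, Sy) token indices -> (B, Sy, E) embeddings with positional info.
        trg = self.token_embedding(trg.long())
        trg = self.token_pos_encoder(trg)
        trg = trg.permute(1, 0, 2)  # (B, Sy, E) -> (Sy, B, E)

        # Causal mask so that position i only attends to positions <= i.
        Sy = trg.shape[0]
        trg_mask = torch.triu(
            torch.full((Sy, Sy), float("-inf"), device=trg.device), diagonal=1
        )

        # nn.TransformerDecoder expects (Sy, B, E) targets and (Sx, B, E) memory.
        out = self.decoder(tgt=trg, memory=z, tgt_mask=trg_mask)
        return self.head(out.permute(1, 0, 2))  # (B, Sy, num_classes)

    @torch.no_grad()
    def predict(self, x: Tensor) -> Tensor:
        """Sketch: greedy autoregressive decoding of the text in an image."""
        z = self.encode(x)
        B = x.shape[0]
        trg = torch.full(
            (B, self.max_output_len), PAD_INDEX, dtype=torch.long, device=x.device
        )
        trg[:, 0] = START_INDEX
        for i in range(1, self.max_output_len):
            logits = self.decode(z, trg[:, :i])        # (B, i, num_classes)
            trg[:, i] = logits[:, -1].argmax(dim=-1)   # most likely next token
            if (trg[:, i] == END_INDEX).all():
                break
        return trg

The greedy loop stops only once every sequence in the batch has produced END_INDEX; beam search or per-sequence early stopping would be refinements on top of this sketch.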