From 9353a39a18d0542afc177cd134f33f0756820a7d Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Thu, 9 Jun 2022 22:33:34 +0200
Subject: Remove abstract lightning module

---
 text_recognizer/networks/__init__.py         |   1 +
 text_recognizer/networks/base.py             | 102 ---------------------------
 text_recognizer/networks/cnn.py              |  26 -------
 text_recognizer/networks/conv_transformer.py |  99 +++++++++++++++++++++-----
 4 files changed, 82 insertions(+), 146 deletions(-)
 delete mode 100644 text_recognizer/networks/base.py
 delete mode 100644 text_recognizer/networks/cnn.py

diff --git a/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index d9ef58b..f921882 100644
--- a/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
@@ -1 +1,2 @@
 """Network modules"""
+from text_recognizer.networks.conv_transformer import ConvTransformer
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
deleted file mode 100644
index 29c3bbc..0000000
--- a/text_recognizer/networks/base.py
+++ /dev/null
@@ -1,102 +0,0 @@
-"""Base network module."""
-import math
-from typing import Optional, Tuple, Type
-
-from loguru import logger as log
-from torch import nn, Tensor
-
-from text_recognizer.networks.transformer.decoder import Decoder
-
-
-class BaseTransformer(nn.Module):
-    """Base transformer network."""
-
-    def __init__(
-        self,
-        input_dims: Tuple[int, int, int],
-        hidden_dim: int,
-        num_classes: int,
-        pad_index: Tensor,
-        encoder: Type[nn.Module],
-        decoder: Decoder,
-        token_pos_embedding: Optional[Type[nn.Module]] = None,
-    ) -> None:
-        super().__init__()
-        self.input_dims = input_dims
-        self.hidden_dim = hidden_dim
-        self.num_classes = num_classes
-        self.pad_index = pad_index
-        self.encoder = encoder
-        self.decoder = decoder
-
-        # Token embedding.
-        self.token_embedding = nn.Embedding(
-            num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
-        )
-
-        # Positional encoding for decoder tokens.
-        if not self.decoder.has_pos_emb:
-            self.token_pos_embedding = token_pos_embedding
-        else:
-            self.token_pos_embedding = None
-            log.debug("Decoder already have a positional embedding.")
-
-        # Output layer
-        self.to_logits = nn.Linear(
-            in_features=self.hidden_dim, out_features=self.num_classes
-        )
-
-    def encode(self, x: Tensor) -> Tensor:
-        """Encodes images with encoder."""
-        return self.encoder(x)
-
-    def decode(self, src: Tensor, trg: Tensor) -> Tensor:
-        """Decodes latent images embedding into word pieces.
-
-        Args:
-            src (Tensor): Latent images embedding.
-            trg (Tensor): Word embeddings.
-
-        Shapes:
-            - z: :math: `(B, Sx, E)`
-            - context: :math: `(B, Sy)`
-            - out: :math: `(B, Sy, T)`
-
-            where Sy is the length of the output and T is the number of tokens.
-
-        Returns:
-            Tensor: Sequence of word piece embeddings.
-        """
-        trg = trg.long()
-        trg_mask = trg != self.pad_index
-        trg = self.token_embedding(trg) * math.sqrt(self.hidden_dim)
-        trg = (
-            self.token_pos_embedding(trg)
-            if self.token_pos_embedding is not None
-            else trg
-        )
-        out = self.decoder(x=trg, context=src, input_mask=trg_mask)
-        logits = self.to_logits(out)  # [B, Sy, T]
-        logits = logits.permute(0, 2, 1)  # [B, T, Sy]
-        return logits
-
-    def forward(self, x: Tensor, context: Tensor) -> Tensor:
-        """Encodes images into word piece logtis.
-
-        Args:
-            x (Tensor): Input image(s).
-            context (Tensor): Target word embeddings.
-
-        Shapes:
-            - x: :math: `(B, C, H, W)`
-            - context: :math: `(B, Sy, T)`
-
-            where B is the batch size, C is the number of input channels, H is
-            the image height and W is the image width.
-
-        Returns:
-            Tensor: Sequence of logits.
-        """
-        z = self.encode(x)
-        logits = self.decode(z, context)
-        return logits
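
The decode path deleted here is carried over into conv_transformer.py below. For reference, a minimal standalone sketch of the flow it implemented (padding mask, scaled token embedding, then the transformer decoder); the dimensions and pad index are illustrative, not taken from the project config:

    import math

    import torch
    from torch import nn

    num_classes, hidden_dim, pad_index = 58, 128, 3
    token_embedding = nn.Embedding(num_classes, hidden_dim)

    trg = torch.tensor([[1, 5, 7, pad_index, pad_index]])  # (B, Sy) token indices
    trg_mask = trg != pad_index                             # (B, Sy) padding mask
    emb = token_embedding(trg) * math.sqrt(hidden_dim)      # (B, Sy, hidden_dim)
    # emb (plus an optional positional embedding) and trg_mask are then passed
    # to the transformer decoder together with the encoded image features.
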
diff --git a/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
deleted file mode 100644
index 5e2a7f4..0000000
--- a/text_recognizer/networks/cnn.py
+++ /dev/null
@@ -1,26 +0,0 @@
-"""Simple convolutional network."""
-import torch
-from torch import nn, Tensor
-
-
-class CNN(nn.Module):
-    def __init__(self, channels: int, depth: int) -> None:
-        super().__init__()
-        self.layers = self._build(channels, depth)
-
-    def _build(self, channels: int, depth: int) -> nn.Sequential:
-        layers = []
-        for i in range(depth):
-            layers.append(
-                nn.Conv2d(
-                    in_channels=1 if i == 0 else channels,
-                    out_channels=channels,
-                    kernel_size=3,
-                    stride=2,
-                )
-            )
-            layers.append(nn.Mish(inplace=True))
-        return nn.Sequential(*layers)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return self.layers(x)
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index e374bd8..d66643b 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -1,17 +1,17 @@
-"""Vision transformer for character recognition."""
+"""Base network module."""
 from typing import Optional, Tuple, Type
 
+from loguru import logger as log
 from torch import nn, Tensor
 
-from text_recognizer.networks.base import BaseTransformer
 from text_recognizer.networks.transformer.decoder import Decoder
 from text_recognizer.networks.transformer.embeddings.axial import (
     AxialPositionalEmbedding,
 )
 
 
-class ConvTransformer(BaseTransformer):
-    """Convolutional encoder and transformer decoder network."""
+class ConvTransformer(nn.Module):
+    """Base transformer network."""
 
     def __init__(
         self,
@@ -21,20 +21,30 @@ class ConvTransformer(BaseTransformer):
         pad_index: Tensor,
         encoder: Type[nn.Module],
         decoder: Decoder,
-        pixel_pos_embedding: AxialPositionalEmbedding,
+        pixel_embedding: AxialPositionalEmbedding,
         token_pos_embedding: Optional[Type[nn.Module]] = None,
     ) -> None:
-        super().__init__(
-            input_dims,
-            hidden_dim,
-            num_classes,
-            pad_index,
-            encoder,
-            decoder,
-            token_pos_embedding,
+        super().__init__()
+        self.input_dims = input_dims
+        self.hidden_dim = hidden_dim
+        self.num_classes = num_classes
+        self.pad_index = pad_index
+        self.encoder = encoder
+        self.decoder = decoder
+
+        # Token embedding.
+        self.token_embedding = nn.Embedding(
+            num_embeddings=self.num_classes, embedding_dim=self.hidden_dim
         )
 
-        self.pixel_pos_embedding = pixel_pos_embedding
+        # Positional encoding for decoder tokens.
+        if not self.decoder.has_pos_emb:
+            self.token_pos_embedding = token_pos_embedding
+        else:
+            self.token_pos_embedding = None
+            log.debug("Decoder already have a positional embedding.")
+
+        self.pixel_embedding = pixel_embedding
 
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
@@ -44,15 +54,17 @@ class ConvTransformer(BaseTransformer):
             kernel_size=1,
         )
 
+        # Output layer
+        self.to_logits = nn.Linear(
+            in_features=self.hidden_dim, out_features=self.num_classes
+        )
+
         # Initalize weights for encoder.
         self.init_weights()
 
     def init_weights(self) -> None:
         """Initalize weights for decoder network and to_logits."""
-        bound = 0.1
-        self.token_embedding.weight.data.uniform_(-bound, bound)
-        self.to_logits.bias.data.zero_()
-        self.to_logits.weight.data.uniform_(-bound, bound)
+        nn.init.kaiming_normal_(self.token_embedding.weight)
 
     def encode(self, x: Tensor) -> Tensor:
         """Encodes an image into a latent feature vector.
@@ -79,3 +91,54 @@ class ConvTransformer(BaseTransformer):
         # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
         z = z.permute(0, 2, 1)
         return z
+
+    def decode(self, src: Tensor, trg: Tensor) -> Tensor:
+        """Decodes latent images embedding into word pieces.
+
+        Args:
+            src (Tensor): Latent image embedding.
+            trg (Tensor): Target token indices.
+
+        Shapes:
+            - src: :math: `(B, Sx, D)`
+            - trg: :math: `(B, Sy)`
+            - out: :math: `(B, C, Sy)`
+
+            where Sy is the length of the output and C is the number of classes.
+
+        Returns:
+            Tensor: Sequence of logits over the word piece classes.
+        """
+        trg = trg.long()
+        trg_mask = trg != self.pad_index
+        trg = self.token_embedding(trg)
+        trg = (
+            self.token_pos_embedding(trg)
+            if self.token_pos_embedding is not None
+            else trg
+        )
+        out = self.decoder(x=trg, context=src, input_mask=trg_mask)
+        logits = self.to_logits(out)  # [B, Sy, C]
+        logits = logits.permute(0, 2, 1)  # [B, C, Sy]
+        return logits
+
+    def forward(self, x: Tensor, context: Tensor) -> Tensor:
+        """Encodes images into word piece logtis.
+
+        Args:
+            x (Tensor): Input image(s).
+            context (Tensor): Target token indices.
+
+        Shapes:
+            - x: :math: `(B, D, H, W)`
+            - context: :math: `(B, Sy)`
+
+            where B is the batch size, D is the number of input channels, H is
+            the image height, W is the image width, and Sy is the target sequence length.
+
+        Returns:
+            Tensor: Sequence of logits.
+        """
+        z = self.encode(x)
+        logits = self.decode(z, context)
+        return logits
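
The permute in decode() puts the class dimension second, which is the layout torch.nn.functional.cross_entropy expects for sequence targets. A minimal sketch with random logits standing in for the network output (sizes and the pad index are illustrative):

    import torch
    import torch.nn.functional as F

    B, Sy, num_classes, pad_index = 4, 89, 58, 3
    logits = torch.randn(B, num_classes, Sy)           # layout returned by forward()
    targets = torch.randint(0, num_classes, (B, Sy))   # target token indices

    loss = F.cross_entropy(logits, targets, ignore_index=pad_index)
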
-- 
cgit v1.2.3-70-g09d2