From bd4bd443f339e95007bfdabf3e060db720f4d4b9 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Tue, 3 Aug 2021 18:18:48 +0200
Subject: Get training working; fix multiple bugs

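Changes needed to get training running:

- conv_transformer: replace the attrs-based initialization with an explicit
  __init__ (dropping the max_output_len field) and permute the logits
  returned by decode() so the class dimension comes second.
- efficientnet/mbconv: collapse the stride conversion and padding helpers
  onto single return statements.
- transformer/layers: make rotary_emb a required constructor argument, build
  the attention/norm/feed-forward partials inside _build_network(), and have
  Decoder hard-wire causal=True instead of declaring it as an attrs field.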
---
 text_recognizer/networks/conv_transformer.py       | 42 ++++++++++------------
 .../networks/encoders/efficientnet/mbconv.py       |  9 ++---
 text_recognizer/networks/transformer/layers.py     | 27 ++++++--------
 3 files changed, 32 insertions(+), 46 deletions(-)

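Note (not part of the commit message): ConvTransformer.decode() now permutes
the head output from [B, Sy, T] to [B, T, Sy], where T is the class dimension
produced by the linear head, which is the layout nn.CrossEntropyLoss expects
for sequence targets. A minimal sketch of a possible call site, with
hypothetical batch size, vocabulary size, sequence length, and pad index:

    import torch
    from torch import nn

    B, num_classes, Sy = 8, 97, 682                   # hypothetical sizes
    logits = torch.randn(B, num_classes, Sy)          # [B, T, Sy], as returned by decode()
    targets = torch.randint(0, num_classes, (B, Sy))  # [B, Sy] token indices

    # CrossEntropyLoss accepts [N, C, d1] inputs with [N, d1] targets; pad
    # tokens (index 3 here is an assumption) are masked out of the loss.
    loss = nn.CrossEntropyLoss(ignore_index=3)(logits, targets)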

diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 09cc654..f3ba49d 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -2,7 +2,6 @@
 import math
 from typing import Tuple
 
-import attr
 from torch import nn, Tensor
 
 from text_recognizer.networks.encoders.efficientnet import EfficientNet
@@ -13,32 +12,28 @@ from text_recognizer.networks.transformer.positional_encodings import (
 )
 
 
-@attr.s(eq=False)
 class ConvTransformer(nn.Module):
     """Convolutional encoder and transformer decoder network."""
 
-    def __attrs_pre_init__(self) -> None:
+    def __init__(
+        self,
+        input_dims: Tuple[int, int, int],
+        hidden_dim: int,
+        dropout_rate: float,
+        num_classes: int,
+        pad_index: Tensor,
+        encoder: EfficientNet,
+        decoder: Decoder,
+    ) -> None:
         super().__init__()
+        self.input_dims = input_dims
+        self.hidden_dim = hidden_dim
+        self.dropout_rate = dropout_rate
+        self.num_classes = num_classes
+        self.pad_index = pad_index
+        self.encoder = encoder
+        self.decoder = decoder
 
-    # Parameters and placeholders,
-    input_dims: Tuple[int, int, int] = attr.ib()
-    hidden_dim: int = attr.ib()
-    dropout_rate: float = attr.ib()
-    max_output_len: int = attr.ib()
-    num_classes: int = attr.ib()
-    pad_index: Tensor = attr.ib()
-
-    # Modules.
-    encoder: EfficientNet = attr.ib()
-    decoder: Decoder = attr.ib()
-
-    latent_encoder: nn.Sequential = attr.ib(init=False)
-    token_embedding: nn.Embedding = attr.ib(init=False)
-    token_pos_encoder: PositionalEncoding = attr.ib(init=False)
-    head: nn.Linear = attr.ib(init=False)
-
-    def __attrs_post_init__(self) -> None:
-        """Post init configuration."""
         # Latent projector for down sampling number of filters and 2d
         # positional encoding.
         self.latent_encoder = nn.Sequential(
@@ -126,7 +121,8 @@ class ConvTransformer(nn.Module):
         context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
         context = self.token_pos_encoder(context)
         out = self.decoder(x=context, context=z, mask=context_mask)
-        logits = self.head(out)
+        logits = self.head(out)  # [B, Sy, T]
+        logits = logits.permute(0, 2, 1)  # [B, T, Sy]
         return logits
 
     def forward(self, x: Tensor, context: Tensor) -> Tensor:
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index e85df87..7bfd9ba 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -11,9 +11,7 @@ from text_recognizer.networks.encoders.efficientnet.utils import stochastic_dept
 
 def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
     """Converts int to tuple."""
-    return (
-        (stride,) * 2 if isinstance(stride, int) else stride
-    )
+    return (stride,) * 2 if isinstance(stride, int) else stride
 
 
 @attr.s(eq=False)
@@ -41,10 +39,7 @@ class MBConvBlock(nn.Module):
     def _configure_padding(self) -> Tuple[int, int, int, int]:
         """Set padding for convolutional layers."""
         if self.stride == (2, 2):
-            return (
-                (self.kernel_size - 1) // 2 - 1,
-                (self.kernel_size - 1) // 2,
-                ) * 2
+            return ((self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2,) * 2
         return ((self.kernel_size - 1) // 2,) * 4
 
     def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index ce443e5..70a0ac7 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -1,5 +1,4 @@
 """Transformer attention layer."""
-from functools import partial
 from typing import Any, Dict, Optional, Tuple
 
 import attr
@@ -27,25 +26,17 @@ class AttentionLayers(nn.Module):
     norm_fn: str = attr.ib()
     ff_fn: str = attr.ib()
     ff_kwargs: Dict = attr.ib()
+    rotary_emb: Optional[RotaryEmbedding] = attr.ib()
     causal: bool = attr.ib(default=False)
     cross_attend: bool = attr.ib(default=False)
     pre_norm: bool = attr.ib(default=True)
-    rotary_emb: Optional[RotaryEmbedding] = attr.ib(default=None)
     layer_types: Tuple[str, ...] = attr.ib(init=False)
     layers: nn.ModuleList = attr.ib(init=False)
-    attn: partial = attr.ib(init=False)
-    norm: partial = attr.ib(init=False)
-    ff: partial = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
         self.layer_types = self._get_layer_types() * self.depth
-        attn = load_partial_fn(
-            self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
-        )
-        norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
-        ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
-        self.layers = self._build_network(attn, norm, ff)
+        self.layers = self._build_network()
 
     def _get_layer_types(self) -> Tuple:
         """Get layer specification."""
@@ -53,10 +44,13 @@ class AttentionLayers(nn.Module):
             return "a", "c", "f"
         return "a", "f"
 
-    def _build_network(
-        self, attn: partial, norm: partial, ff: partial,
-    ) -> nn.ModuleList:
+    def _build_network(self) -> nn.ModuleList:
         """Configures transformer network."""
+        attn = load_partial_fn(
+            self.attn_fn, dim=self.dim, num_heads=self.num_heads, **self.attn_kwargs
+        )
+        norm = load_partial_fn(self.norm_fn, normalized_shape=self.dim)
+        ff = load_partial_fn(self.ff_fn, dim=self.dim, **self.ff_kwargs)
         layers = nn.ModuleList([])
         for layer_type in self.layer_types:
             if layer_type == "a":
@@ -106,6 +100,7 @@ class Encoder(AttentionLayers):
     causal: bool = attr.ib(default=False, init=False)
 
 
-@attr.s(auto_attribs=True, eq=False)
 class Decoder(AttentionLayers):
-    causal: bool = attr.ib(default=True, init=False)
+    def __init__(self, **kwargs: Any) -> None:
+        assert "causal" not in kwargs, "Cannot set causality on decoder"
+        super().__init__(causal=True, **kwargs)
-- 
cgit v1.2.3-70-g09d2