From 31e127c479cac61134bed3d5c4341561eef68c52 Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Tue, 13 Sep 2022 19:07:06 +0200
Subject: Update conv transformer notebook

---
 notebooks/04-conv-transformer.ipynb | 147 +++++++++++-------------------------
 1 file changed, 46 insertions(+), 101 deletions(-)

diff --git a/notebooks/04-conv-transformer.ipynb b/notebooks/04-conv-transformer.ipynb
index 8ded6b6..50779a9 100644
--- a/notebooks/04-conv-transformer.ipynb
+++ b/notebooks/04-conv-transformer.ipynb
@@ -40,7 +40,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 7,
+   "execution_count": 26,
    "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7",
    "metadata": {},
    "outputs": [],
@@ -50,7 +50,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 35,
+   "execution_count": 42,
    "id": "e52ecb01-c975-4e55-925d-1182c7aea473",
    "metadata": {},
    "outputs": [],
@@ -61,17 +61,17 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 36,
+   "execution_count": 43,
    "id": "f939aa37-7b1d-45cc-885c-323c4540bda1",
    "metadata": {},
    "outputs": [
     {
      "data": {
       "text/plain": [
-       "{'_target_': 'text_recognizer.networks.ConvTransformer', 'input_dims': [1, 1, 576, 640], 'hidden_dim': 144, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.EfficientNet', 'arch': 'b0', 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001, 'depth': 5, 'out_channels': 144}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'depth': 6, 'block': {'_target_': 'text_recognizer.networks.transformer.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 144, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 144, 'num_heads': 8, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 144}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 144, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 144, 'axial_shape': [3, 63], 'axial_dims': [72, 72]}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 144, 'dropout_rate': 0.1, 'max_len': 89}}"
+       "{'_target_': 'text_recognizer.networks.ConvTransformer', 'input_dims': [1, 1, 576, 640], 'hidden_dim': 128, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.convnext.ConvNext', 'dim': 16, 'dim_mults': [2, 4, 8], 'depths': [3, 3, 6], 'downsampling_factors': [[2, 2], [2, 2], [2, 2]]}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'depth': 10, 'block': {'_target_': 'text_recognizer.networks.transformer.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 128}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 128, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 128, 'axial_shape': [7, 128], 'axial_dims': [64, 64]}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 128, 'dropout_rate': 0.1, 'max_len': 89}}"
       ]
      },
-     "execution_count": 36,
+     "execution_count": 43,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -82,7 +82,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 37,
+   "execution_count": 44,
    "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0",
    "metadata": {
     "scrolled": false
@@ -94,7 +94,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 38,
+   "execution_count": 45,
    "id": "618b997c-e6a6-4487-b70c-9d260cb556d3",
    "metadata": {},
    "outputs": [],
@@ -104,17 +104,10 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 39,
+   "execution_count": 46,
    "id": "7daf1f49",
    "metadata": {},
    "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "torch.Size([1, 144, 3, 63])\n"
-     ]
-    },
     {
      "data": {
       "text/plain": [
@@ -122,100 +115,52 @@
        "Layer (type:depth-idx)                                       Output Shape              Param #\n",
        "==============================================================================================================\n",
        "ConvTransformer                                              [1, 58, 89]               --\n",
-       "├─EfficientNet: 1-1                                          [1, 144, 3, 63]           850,880\n",
-       "│    └─Sequential: 2-1                                       [1, 32, 26, 510]          --\n",
-       "│    │    └─ZeroPad2d: 3-1                                   [1, 1, 57, 1025]          --\n",
-       "│    │    └─Conv2d: 3-2                                      [1, 32, 26, 510]          1,568\n",
-       "│    │    └─BatchNorm2d: 3-3                                 [1, 32, 26, 510]          64\n",
-       "│    │    └─Mish: 3-4                                        [1, 32, 26, 510]          --\n",
+       "├─ConvNext: 1-1                                              [1, 128, 7, 128]          1,051,488\n",
+       "│    └─Conv2d: 2-1                                           [1, 16, 56, 1024]         800\n",
        "│    └─ModuleList: 2                                         --                        --\n",
-       "│    │    └─MBConvBlock: 3-5                                 [1, 16, 26, 510]          --\n",
-       "│    │    │    └─Depthwise: 4-1                              [1, 32, 26, 510]          352\n",
-       "│    │    │    └─SqueezeAndExcite: 4-2                       [1, 32, 26, 510]          552\n",
-       "│    │    │    └─Pointwise: 4-3                              [1, 16, 26, 510]          544\n",
-       "│    │    └─MBConvBlock: 3-6                                 [1, 24, 13, 255]          --\n",
-       "│    │    │    └─InvertedBottleneck: 4-4                     [1, 96, 26, 510]          1,728\n",
-       "│    │    │    └─Depthwise: 4-5                              [1, 96, 13, 255]          1,056\n",
-       "│    │    │    └─SqueezeAndExcite: 4-6                       [1, 96, 13, 255]          868\n",
-       "│    │    │    └─Pointwise: 4-7                              [1, 24, 13, 255]          2,352\n",
-       "│    │    └─MBConvBlock: 3-7                                 [1, 24, 13, 255]          --\n",
-       "│    │    │    └─InvertedBottleneck: 4-8                     [1, 144, 13, 255]         3,744\n",
-       "│    │    │    └─Depthwise: 4-9                              [1, 144, 13, 255]         1,584\n",
-       "│    │    │    └─SqueezeAndExcite: 4-10                      [1, 144, 13, 255]         1,878\n",
-       "│    │    │    └─Pointwise: 4-11                             [1, 24, 13, 255]          3,504\n",
-       "│    │    └─MBConvBlock: 3-8                                 [1, 40, 6, 127]           --\n",
-       "│    │    │    └─InvertedBottleneck: 4-12                    [1, 144, 13, 255]         3,744\n",
-       "│    │    │    └─Depthwise: 4-13                             [1, 144, 6, 127]          3,888\n",
-       "│    │    │    └─SqueezeAndExcite: 4-14                      [1, 144, 6, 127]          1,878\n",
-       "│    │    │    └─Pointwise: 4-15                             [1, 40, 6, 127]           5,840\n",
-       "│    │    └─MBConvBlock: 3-9                                 [1, 40, 6, 127]           --\n",
-       "│    │    │    └─InvertedBottleneck: 4-16                    [1, 240, 6, 127]          10,080\n",
-       "│    │    │    └─Depthwise: 4-17                             [1, 240, 6, 127]          6,480\n",
-       "│    │    │    └─SqueezeAndExcite: 4-18                      [1, 240, 6, 127]          5,050\n",
-       "│    │    │    └─Pointwise: 4-19                             [1, 40, 6, 127]           9,680\n",
-       "│    │    └─MBConvBlock: 3-10                                [1, 80, 3, 63]            --\n",
-       "│    │    │    └─InvertedBottleneck: 4-20                    [1, 240, 6, 127]          10,080\n",
-       "│    │    │    └─Depthwise: 4-21                             [1, 240, 3, 63]           2,640\n",
-       "│    │    │    └─SqueezeAndExcite: 4-22                      [1, 240, 3, 63]           5,050\n",
-       "│    │    │    └─Pointwise: 4-23                             [1, 80, 3, 63]            19,360\n",
-       "│    │    └─MBConvBlock: 3-11                                [1, 80, 3, 63]            --\n",
-       "│    │    │    └─InvertedBottleneck: 4-24                    [1, 480, 3, 63]           39,360\n",
-       "│    │    │    └─Depthwise: 4-25                             [1, 480, 3, 63]           5,280\n",
-       "│    │    │    └─SqueezeAndExcite: 4-26                      [1, 480, 3, 63]           19,700\n",
-       "│    │    │    └─Pointwise: 4-27                             [1, 80, 3, 63]            38,560\n",
-       "│    │    └─MBConvBlock: 3-12                                [1, 80, 3, 63]            --\n",
-       "│    │    │    └─InvertedBottleneck: 4-28                    [1, 480, 3, 63]           39,360\n",
-       "│    │    │    └─Depthwise: 4-29                             [1, 480, 3, 63]           5,280\n",
-       "│    │    │    └─SqueezeAndExcite: 4-30                      [1, 480, 3, 63]           19,700\n",
-       "│    │    │    └─Pointwise: 4-31                             [1, 80, 3, 63]            38,560\n",
-       "│    │    └─MBConvBlock: 3-13                                [1, 112, 3, 63]           --\n",
-       "│    │    │    └─InvertedBottleneck: 4-32                    [1, 480, 3, 63]           39,360\n",
-       "│    │    │    └─Depthwise: 4-33                             [1, 480, 3, 63]           12,960\n",
-       "│    │    │    └─SqueezeAndExcite: 4-34                      [1, 480, 3, 63]           19,700\n",
-       "│    │    │    └─Pointwise: 4-35                             [1, 112, 3, 63]           53,984\n",
-       "│    │    └─MBConvBlock: 3-14                                [1, 112, 3, 63]           --\n",
-       "│    │    │    └─InvertedBottleneck: 4-36                    [1, 672, 3, 63]           76,608\n",
-       "│    │    │    └─Depthwise: 4-37                             [1, 672, 3, 63]           18,144\n",
-       "│    │    │    └─SqueezeAndExcite: 4-38                      [1, 672, 3, 63]           38,332\n",
-       "│    │    │    └─Pointwise: 4-39                             [1, 112, 3, 63]           75,488\n",
-       "│    │    └─MBConvBlock: 3-15                                [1, 112, 3, 63]           --\n",
-       "│    │    │    └─InvertedBottleneck: 4-40                    [1, 672, 3, 63]           76,608\n",
-       "│    │    │    └─Depthwise: 4-41                             [1, 672, 3, 63]           18,144\n",
-       "│    │    │    └─SqueezeAndExcite: 4-42                      [1, 672, 3, 63]           38,332\n",
-       "│    │    │    └─Pointwise: 4-43                             [1, 112, 3, 63]           75,488\n",
-       "│    └─Sequential: 2-2                                       [1, 144, 3, 63]           --\n",
-       "│    │    └─Conv2d: 3-16                                     [1, 144, 3, 63]           16,128\n",
-       "│    │    └─BatchNorm2d: 3-17                                [1, 144, 3, 63]           288\n",
-       "│    │    └─Dropout: 3-18                                    [1, 144, 3, 63]           --\n",
-       "├─Conv2d: 1-2                                                [1, 144, 3, 63]           20,880\n",
-       "├─AxialPositionalEmbeddingImage: 1-3                         [1, 144, 3, 63]           --\n",
-       "│    └─AxialPositionalEmbedding: 2-3                         [1, 189, 144]             4,752\n",
-       "├─Embedding: 1-4                                             [1, 89, 144]              8,352\n",
-       "├─PositionalEncoding: 1-5                                    [1, 89, 144]              --\n",
-       "│    └─Dropout: 2-4                                          [1, 89, 144]              --\n",
-       "├─Decoder: 1-6                                               [1, 89, 144]              --\n",
+       "│    │    └─ModuleList: 3                                    --                        --\n",
+       "│    │    │    └─ConvNextBlock: 4-1                          [1, 16, 56, 1024]         10,080\n",
+       "│    │    │    └─Downsample: 4-2                             [1, 32, 28, 512]          2,080\n",
+       "│    │    └─ModuleList: 3                                    --                        --\n",
+       "│    │    │    └─ConvNextBlock: 4-3                          [1, 32, 28, 512]          38,592\n",
+       "│    │    │    └─Downsample: 4-4                             [1, 64, 14, 256]          8,256\n",
+       "│    │    └─ModuleList: 3                                    --                        --\n",
+       "│    │    │    └─ConvNextBlock: 4-5                          [1, 64, 14, 256]          150,912\n",
+       "│    │    │    └─Downsample: 4-6                             [1, 128, 7, 128]          32,896\n",
+       "│    └─LayerNorm: 2-2                                        [1, 128, 7, 128]          128\n",
+       "├─Conv2d: 1-2                                                [1, 128, 7, 128]          16,512\n",
+       "├─AxialPositionalEmbeddingImage: 1-3                         [1, 128, 7, 128]          --\n",
+       "│    └─AxialPositionalEmbedding: 2-3                         [1, 896, 128]             8,640\n",
+       "├─Embedding: 1-4                                             [1, 89, 128]              7,424\n",
+       "├─PositionalEncoding: 1-5                                    [1, 89, 128]              --\n",
+       "│    └─Dropout: 2-4                                          [1, 89, 128]              --\n",
+       "├─Decoder: 1-6                                               [1, 89, 128]              --\n",
        "│    └─ModuleList: 2                                         --                        --\n",
-       "│    │    └─DecoderBlock: 3-19                               [1, 89, 144]              --\n",
-       "│    │    └─DecoderBlock: 3-20                               [1, 89, 144]              --\n",
-       "│    │    └─DecoderBlock: 3-21                               [1, 89, 144]              --\n",
-       "│    │    └─DecoderBlock: 3-22                               [1, 89, 144]              --\n",
-       "│    │    └─DecoderBlock: 3-23                               [1, 89, 144]              --\n",
-       "│    │    └─DecoderBlock: 3-24                               [1, 89, 144]              --\n",
-       "├─Linear: 1-7                                                [1, 89, 58]               8,410\n",
+       "│    │    └─DecoderBlock: 3-1                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-2                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-3                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-4                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-5                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-6                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-7                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-8                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-9                                [1, 89, 128]              --\n",
+       "│    │    └─DecoderBlock: 3-10                               [1, 89, 128]              --\n",
+       "├─Linear: 1-7                                                [1, 89, 58]               7,482\n",
        "==============================================================================================================\n",
-       "Total params: 6,090,138\n",
-       "Trainable params: 6,090,138\n",
+       "Total params: 10,195,450\n",
+       "Trainable params: 10,195,450\n",
        "Non-trainable params: 0\n",
-       "Total mult-adds (M): 313.64\n",
+       "Total mult-adds (G): 8.47\n",
        "==============================================================================================================\n",
        "Input size (MB): 0.23\n",
-       "Forward/backward pass size (MB): 145.27\n",
-       "Params size (MB): 24.36\n",
-       "Estimated Total Size (MB): 169.86\n",
+       "Forward/backward pass size (MB): 442.16\n",
+       "Params size (MB): 40.78\n",
+       "Estimated Total Size (MB): 483.17\n",
        "=============================================================================================================="
       ]
      },
-     "execution_count": 39,
+     "execution_count": 46,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -229,7 +174,7 @@
    "execution_count": 22,
    "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f",
    "metadata": {
-    "scrolled": false
+    "scrolled": true
    },
    "outputs": [
     {
-- 
cgit v1.2.3-70-g09d2