From 31e127c479cac61134bed3d5c4341561eef68c52 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 13 Sep 2022 19:07:06 +0200 Subject: Update conv transformer notebook --- notebooks/04-conv-transformer.ipynb | 147 +++++++++++------------------------- 1 file changed, 46 insertions(+), 101 deletions(-) (limited to 'notebooks/04-conv-transformer.ipynb') diff --git a/notebooks/04-conv-transformer.ipynb b/notebooks/04-conv-transformer.ipynb index 8ded6b6..50779a9 100644 --- a/notebooks/04-conv-transformer.ipynb +++ b/notebooks/04-conv-transformer.ipynb @@ -40,7 +40,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 26, "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7", "metadata": {}, "outputs": [], @@ -50,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 42, "id": "e52ecb01-c975-4e55-925d-1182c7aea473", "metadata": {}, "outputs": [], @@ -61,17 +61,17 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 43, "id": "f939aa37-7b1d-45cc-885c-323c4540bda1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'_target_': 'text_recognizer.networks.ConvTransformer', 'input_dims': [1, 1, 576, 640], 'hidden_dim': 144, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.EfficientNet', 'arch': 'b0', 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001, 'depth': 5, 'out_channels': 144}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'depth': 6, 'block': {'_target_': 'text_recognizer.networks.transformer.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 144, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 144, 'num_heads': 8, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 144}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 144, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 144, 'axial_shape': [3, 63], 'axial_dims': [72, 72]}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 144, 'dropout_rate': 0.1, 'max_len': 89}}" + "{'_target_': 'text_recognizer.networks.ConvTransformer', 'input_dims': [1, 1, 576, 640], 'hidden_dim': 128, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.convnext.ConvNext', 'dim': 16, 'dim_mults': [2, 4, 8], 'depths': [3, 3, 6], 'downsampling_factors': [[2, 2], [2, 2], [2, 2]]}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'depth': 10, 'block': {'_target_': 'text_recognizer.networks.transformer.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 128}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 128, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 128, 'axial_shape': [7, 128], 'axial_dims': [64, 64]}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 128, 'dropout_rate': 0.1, 'max_len': 89}}" ] }, - "execution_count": 36, + "execution_count": 43, "metadata": {}, "output_type": "execute_result" } @@ -82,7 +82,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 44, "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0", "metadata": { "scrolled": false @@ -94,7 +94,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 45, "id": "618b997c-e6a6-4487-b70c-9d260cb556d3", "metadata": {}, "outputs": [], @@ -104,17 +104,10 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 46, "id": "7daf1f49", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([1, 144, 3, 63])\n" - ] - }, { "data": { "text/plain": [ @@ -122,100 +115,52 @@ "Layer (type:depth-idx) Output Shape Param #\n", "==============================================================================================================\n", "ConvTransformer [1, 58, 89] --\n", - "├─EfficientNet: 1-1 [1, 144, 3, 63] 850,880\n", - "│ └─Sequential: 2-1 [1, 32, 26, 510] --\n", - "│ │ └─ZeroPad2d: 3-1 [1, 1, 57, 1025] --\n", - "│ │ └─Conv2d: 3-2 [1, 32, 26, 510] 1,568\n", - "│ │ └─BatchNorm2d: 3-3 [1, 32, 26, 510] 64\n", - "│ │ └─Mish: 3-4 [1, 32, 26, 510] --\n", + "├─ConvNext: 1-1 [1, 128, 7, 128] 1,051,488\n", + "│ └─Conv2d: 2-1 [1, 16, 56, 1024] 800\n", "│ └─ModuleList: 2 -- --\n", - "│ │ └─MBConvBlock: 3-5 [1, 16, 26, 510] --\n", - "│ │ │ └─Depthwise: 4-1 [1, 32, 26, 510] 352\n", - "│ │ │ └─SqueezeAndExcite: 4-2 [1, 32, 26, 510] 552\n", - "│ │ │ └─Pointwise: 4-3 [1, 16, 26, 510] 544\n", - "│ │ └─MBConvBlock: 3-6 [1, 24, 13, 255] --\n", - "│ │ │ └─InvertedBottleneck: 4-4 [1, 96, 26, 510] 1,728\n", - "│ │ │ └─Depthwise: 4-5 [1, 96, 13, 255] 1,056\n", - "│ │ │ └─SqueezeAndExcite: 4-6 [1, 96, 13, 255] 868\n", - "│ │ │ └─Pointwise: 4-7 [1, 24, 13, 255] 2,352\n", - "│ │ └─MBConvBlock: 3-7 [1, 24, 13, 255] --\n", - "│ │ │ └─InvertedBottleneck: 4-8 [1, 144, 13, 255] 3,744\n", - "│ │ │ └─Depthwise: 4-9 [1, 144, 13, 255] 1,584\n", - "│ │ │ └─SqueezeAndExcite: 4-10 [1, 144, 13, 255] 1,878\n", - "│ │ │ └─Pointwise: 4-11 [1, 24, 13, 255] 3,504\n", - "│ │ └─MBConvBlock: 3-8 [1, 40, 6, 127] --\n", - "│ │ │ └─InvertedBottleneck: 4-12 [1, 144, 13, 255] 3,744\n", - "│ │ │ └─Depthwise: 4-13 [1, 144, 6, 127] 3,888\n", - "│ │ │ └─SqueezeAndExcite: 4-14 [1, 144, 6, 127] 1,878\n", - "│ │ │ └─Pointwise: 4-15 [1, 40, 6, 127] 5,840\n", - "│ │ └─MBConvBlock: 3-9 [1, 40, 6, 127] --\n", - "│ │ │ └─InvertedBottleneck: 4-16 [1, 240, 6, 127] 10,080\n", - "│ │ │ └─Depthwise: 4-17 [1, 240, 6, 127] 6,480\n", - "│ │ │ └─SqueezeAndExcite: 4-18 [1, 240, 6, 127] 5,050\n", - "│ │ │ └─Pointwise: 4-19 [1, 40, 6, 127] 9,680\n", - "│ │ └─MBConvBlock: 3-10 [1, 80, 3, 63] --\n", - "│ │ │ └─InvertedBottleneck: 4-20 [1, 240, 6, 127] 10,080\n", - "│ │ │ └─Depthwise: 4-21 [1, 240, 3, 63] 2,640\n", - "│ │ │ └─SqueezeAndExcite: 4-22 [1, 240, 3, 63] 5,050\n", - "│ │ │ └─Pointwise: 4-23 [1, 80, 3, 63] 19,360\n", - "│ │ └─MBConvBlock: 3-11 [1, 80, 3, 63] --\n", - "│ │ │ └─InvertedBottleneck: 4-24 [1, 480, 3, 63] 39,360\n", - "│ │ │ └─Depthwise: 4-25 [1, 480, 3, 63] 5,280\n", - "│ │ │ └─SqueezeAndExcite: 4-26 [1, 480, 3, 63] 19,700\n", - "│ │ │ └─Pointwise: 4-27 [1, 80, 3, 63] 38,560\n", - "│ │ └─MBConvBlock: 3-12 [1, 80, 3, 63] --\n", - "│ │ │ └─InvertedBottleneck: 4-28 [1, 480, 3, 63] 39,360\n", - "│ │ │ └─Depthwise: 4-29 [1, 480, 3, 63] 5,280\n", - "│ │ │ └─SqueezeAndExcite: 4-30 [1, 480, 3, 63] 19,700\n", - "│ │ │ └─Pointwise: 4-31 [1, 80, 3, 63] 38,560\n", - "│ │ └─MBConvBlock: 3-13 [1, 112, 3, 63] --\n", - "│ │ │ └─InvertedBottleneck: 4-32 [1, 480, 3, 63] 39,360\n", - "│ │ │ └─Depthwise: 4-33 [1, 480, 3, 63] 12,960\n", - "│ │ │ └─SqueezeAndExcite: 4-34 [1, 480, 3, 63] 19,700\n", - "│ │ │ └─Pointwise: 4-35 [1, 112, 3, 63] 53,984\n", - "│ │ └─MBConvBlock: 3-14 [1, 112, 3, 63] --\n", - "│ │ │ └─InvertedBottleneck: 4-36 [1, 672, 3, 63] 76,608\n", - "│ │ │ └─Depthwise: 4-37 [1, 672, 3, 63] 18,144\n", - "│ │ │ └─SqueezeAndExcite: 4-38 [1, 672, 3, 63] 38,332\n", - "│ │ │ └─Pointwise: 4-39 [1, 112, 3, 63] 75,488\n", - "│ │ └─MBConvBlock: 3-15 [1, 112, 3, 63] --\n", - "│ │ │ └─InvertedBottleneck: 4-40 [1, 672, 3, 63] 76,608\n", - "│ │ │ └─Depthwise: 4-41 [1, 672, 3, 63] 18,144\n", - "│ │ │ └─SqueezeAndExcite: 4-42 [1, 672, 3, 63] 38,332\n", - "│ │ │ └─Pointwise: 4-43 [1, 112, 3, 63] 75,488\n", - "│ └─Sequential: 2-2 [1, 144, 3, 63] --\n", - "│ │ └─Conv2d: 3-16 [1, 144, 3, 63] 16,128\n", - "│ │ └─BatchNorm2d: 3-17 [1, 144, 3, 63] 288\n", - "│ │ └─Dropout: 3-18 [1, 144, 3, 63] --\n", - "├─Conv2d: 1-2 [1, 144, 3, 63] 20,880\n", - "├─AxialPositionalEmbeddingImage: 1-3 [1, 144, 3, 63] --\n", - "│ └─AxialPositionalEmbedding: 2-3 [1, 189, 144] 4,752\n", - "├─Embedding: 1-4 [1, 89, 144] 8,352\n", - "├─PositionalEncoding: 1-5 [1, 89, 144] --\n", - "│ └─Dropout: 2-4 [1, 89, 144] --\n", - "├─Decoder: 1-6 [1, 89, 144] --\n", + "│ │ └─ModuleList: 3 -- --\n", + "│ │ │ └─ConvNextBlock: 4-1 [1, 16, 56, 1024] 10,080\n", + "│ │ │ └─Downsample: 4-2 [1, 32, 28, 512] 2,080\n", + "│ │ └─ModuleList: 3 -- --\n", + "│ │ │ └─ConvNextBlock: 4-3 [1, 32, 28, 512] 38,592\n", + "│ │ │ └─Downsample: 4-4 [1, 64, 14, 256] 8,256\n", + "│ │ └─ModuleList: 3 -- --\n", + "│ │ │ └─ConvNextBlock: 4-5 [1, 64, 14, 256] 150,912\n", + "│ │ │ └─Downsample: 4-6 [1, 128, 7, 128] 32,896\n", + "│ └─LayerNorm: 2-2 [1, 128, 7, 128] 128\n", + "├─Conv2d: 1-2 [1, 128, 7, 128] 16,512\n", + "├─AxialPositionalEmbeddingImage: 1-3 [1, 128, 7, 128] --\n", + "│ └─AxialPositionalEmbedding: 2-3 [1, 896, 128] 8,640\n", + "├─Embedding: 1-4 [1, 89, 128] 7,424\n", + "├─PositionalEncoding: 1-5 [1, 89, 128] --\n", + "│ └─Dropout: 2-4 [1, 89, 128] --\n", + "├─Decoder: 1-6 [1, 89, 128] --\n", "│ └─ModuleList: 2 -- --\n", - "│ │ └─DecoderBlock: 3-19 [1, 89, 144] --\n", - "│ │ └─DecoderBlock: 3-20 [1, 89, 144] --\n", - "│ │ └─DecoderBlock: 3-21 [1, 89, 144] --\n", - "│ │ └─DecoderBlock: 3-22 [1, 89, 144] --\n", - "│ │ └─DecoderBlock: 3-23 [1, 89, 144] --\n", - "│ │ └─DecoderBlock: 3-24 [1, 89, 144] --\n", - "├─Linear: 1-7 [1, 89, 58] 8,410\n", + "│ │ └─DecoderBlock: 3-1 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-2 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-3 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-4 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-5 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-6 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-7 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-8 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-9 [1, 89, 128] --\n", + "│ │ └─DecoderBlock: 3-10 [1, 89, 128] --\n", + "├─Linear: 1-7 [1, 89, 58] 7,482\n", "==============================================================================================================\n", - "Total params: 6,090,138\n", - "Trainable params: 6,090,138\n", + "Total params: 10,195,450\n", + "Trainable params: 10,195,450\n", "Non-trainable params: 0\n", - "Total mult-adds (M): 313.64\n", + "Total mult-adds (G): 8.47\n", "==============================================================================================================\n", "Input size (MB): 0.23\n", - "Forward/backward pass size (MB): 145.27\n", - "Params size (MB): 24.36\n", - "Estimated Total Size (MB): 169.86\n", + "Forward/backward pass size (MB): 442.16\n", + "Params size (MB): 40.78\n", + "Estimated Total Size (MB): 483.17\n", "==============================================================================================================" ] }, - "execution_count": 39, + "execution_count": 46, "metadata": {}, "output_type": "execute_result" } @@ -229,7 +174,7 @@ "execution_count": 22, "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f", "metadata": { - "scrolled": false + "scrolled": true }, "outputs": [ { -- cgit v1.2.3-70-g09d2