summaryrefslogtreecommitdiff
path: root/notebooks/04-conv-transformer.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/04-conv-transformer.ipynb')
-rw-r--r--notebooks/04-conv-transformer.ipynb180
1 files changed, 53 insertions, 127 deletions
diff --git a/notebooks/04-conv-transformer.ipynb b/notebooks/04-conv-transformer.ipynb
index b864098..0d8b370 100644
--- a/notebooks/04-conv-transformer.ipynb
+++ b/notebooks/04-conv-transformer.ipynb
@@ -2,19 +2,10 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 1,
"id": "7c02ae76-b540-4b16-9492-e9210b3b9249",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICE'] = ''\n",
@@ -49,7 +40,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 5,
"id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7",
"metadata": {},
"outputs": [],
@@ -59,7 +50,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 6,
"id": "e52ecb01-c975-4e55-925d-1182c7aea473",
"metadata": {},
"outputs": [],
@@ -70,17 +61,17 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 7,
"id": "f939aa37-7b1d-45cc-885c-323c4540bda1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "{'_target_': 'text_recognizer.networks.ConvTransformer', 'input_dims': [1, 1, 576, 640], 'hidden_dim': 128, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.convnext.ConvNext', 'dim': 16, 'dim_mults': [2, 4, 8], 'depths': [3, 3, 6], 'downsampling_factors': [[2, 2], [2, 2], [2, 2]]}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 128, 'depth': 10, 'block': {'_target_': 'text_recognizer.networks.transformer.decoder_block.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 128}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 128, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 128, 'axial_shape': [7, 128], 'axial_dims': [64, 64]}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 128, 'dropout_rate': 0.1, 'max_len': 89}}"
+ "{'_target_': 'text_recognizer.networks.ConvTransformer', 'encoder': {'_target_': 'text_recognizer.networks.image_encoder.ImageEncoder', 'encoder': {'_target_': 'text_recognizer.networks.convnext.ConvNext', 'dim': 16, 'dim_mults': [2, 4, 8], 'depths': [3, 3, 6], 'downsampling_factors': [[2, 2], [2, 2], [2, 2]]}, 'pixel_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage', 'dim': 128, 'axial_shape': [7, 128], 'axial_dims': [64, 64]}}, 'decoder': {'_target_': 'text_recognizer.networks.text_decoder.TextDecoder', 'hidden_dim': 128, 'num_classes': 58, 'pad_index': 3, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 128, 'depth': 10, 'block': {'_target_': 'text_recognizer.networks.transformer.decoder_block.DecoderBlock', 'self_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': True}, 'cross_attn': {'_target_': 'text_recognizer.networks.transformer.Attention', 'dim': 128, 'num_heads': 12, 'dim_head': 64, 'dropout_rate': 0.2, 'causal': False}, 'norm': {'_target_': 'text_recognizer.networks.transformer.RMSNorm', 'dim': 128}, 'ff': {'_target_': 'text_recognizer.networks.transformer.FeedForward', 'dim': 128, 'dim_out': None, 'expansion_factor': 2, 'glu': True, 'dropout_rate': 0.2}}, 'rotary_embedding': {'_target_': 'text_recognizer.networks.transformer.RotaryEmbedding', 'dim': 64}}, 'token_pos_embedding': {'_target_': 'text_recognizer.networks.transformer.embeddings.fourier.PositionalEncoding', 'dim': 128, 'dropout_rate': 0.1, 'max_len': 89}}}"
]
},
- "execution_count": 5,
+ "execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@@ -91,7 +82,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 13,
"id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0",
"metadata": {
"scrolled": false
@@ -103,7 +94,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 14,
"id": "618b997c-e6a6-4487-b70c-9d260cb556d3",
"metadata": {},
"outputs": [],
@@ -113,125 +104,60 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 15,
"id": "7daf1f49",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "=========================================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "=========================================================================================================\n",
- "ConvTransformer [1, 58, 89] --\n",
- "├─ConvNext: 1-1 [1, 128, 7, 128] 1,051,488\n",
- "│ └─Conv2d: 2-1 [1, 16, 56, 1024] 800\n",
- "│ └─ModuleList: 2 -- --\n",
- "│ │ └─ModuleList: 3 -- --\n",
- "│ │ │ └─ConvNextBlock: 4-1 [1, 16, 56, 1024] 10,080\n",
- "│ │ │ └─Downsample: 4-2 [1, 32, 28, 512] 2,080\n",
- "│ │ └─ModuleList: 3 -- --\n",
- "│ │ │ └─ConvNextBlock: 4-3 [1, 32, 28, 512] 38,592\n",
- "│ │ │ └─Downsample: 4-4 [1, 64, 14, 256] 8,256\n",
- "│ │ └─ModuleList: 3 -- --\n",
- "│ │ │ └─ConvNextBlock: 4-5 [1, 64, 14, 256] 150,912\n",
- "│ │ │ └─Downsample: 4-6 [1, 128, 7, 128] 32,896\n",
- "│ └─Identity: 2-2 [1, 128, 7, 128] --\n",
- "│ └─LayerNorm: 2-3 [1, 128, 7, 128] 128\n",
- "├─Conv2d: 1-2 [1, 128, 7, 128] 16,512\n",
- "├─AxialPositionalEmbeddingImage: 1-3 [1, 128, 7, 128] --\n",
- "│ └─AxialPositionalEmbedding: 2-4 [1, 896, 128] 8,640\n",
- "├─Embedding: 1-4 [1, 89, 128] 7,424\n",
- "├─PositionalEncoding: 1-5 [1, 89, 128] --\n",
- "│ └─Dropout: 2-5 [1, 89, 128] --\n",
- "├─Decoder: 1-6 [1, 89, 128] --\n",
- "│ └─ModuleList: 2 -- --\n",
- "│ │ └─DecoderBlock: 3-1 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-7 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-8 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-9 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-10 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-11 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-12 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-2 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-13 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-14 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-15 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-16 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-17 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-18 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-3 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-19 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-20 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-21 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-22 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-23 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-24 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-4 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-25 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-26 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-27 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-28 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-29 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-30 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-5 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-31 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-32 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-33 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-34 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-35 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-36 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-6 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-37 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-38 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-39 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-40 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-41 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-42 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-7 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-43 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-44 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-45 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-46 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-47 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-48 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-8 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-49 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-50 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-51 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-52 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-53 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-54 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-9 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-55 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-56 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-57 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-58 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-59 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-60 [1, 89, 128] 98,944\n",
- "│ │ └─DecoderBlock: 3-10 [1, 89, 128] --\n",
- "│ │ │ └─RMSNorm: 4-61 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-62 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-63 [1, 89, 128] 128\n",
- "│ │ │ └─Attention: 4-64 [1, 89, 128] 393,344\n",
- "│ │ │ └─RMSNorm: 4-65 [1, 89, 128] 128\n",
- "│ │ │ └─FeedForward: 4-66 [1, 89, 128] 98,944\n",
- "│ └─LayerNorm: 2-6 [1, 89, 128] 256\n",
- "├─Linear: 1-7 [1, 89, 58] 7,482\n",
- "=========================================================================================================\n",
- "Total params: 10,195,706\n",
- "Trainable params: 10,195,706\n",
+ "==============================================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "==============================================================================================================\n",
+ "ConvTransformer [1, 58, 89] --\n",
+ "├─ImageEncoder: 1-1 [1, 896, 128] --\n",
+ "│ └─ConvNext: 2-1 [1, 128, 7, 128] --\n",
+ "│ │ └─Conv2d: 3-1 [1, 16, 56, 1024] 800\n",
+ "│ │ └─ModuleList: 3-2 -- --\n",
+ "│ │ │ └─ModuleList: 4-1 -- 42,400\n",
+ "│ │ │ └─ModuleList: 4-2 -- 162,624\n",
+ "│ │ │ └─ModuleList: 4-3 -- 1,089,280\n",
+ "│ │ └─Identity: 3-3 [1, 128, 7, 128] --\n",
+ "│ │ └─LayerNorm: 3-4 [1, 128, 7, 128] 128\n",
+ "│ └─AxialPositionalEmbeddingImage: 2-2 [1, 128, 7, 128] --\n",
+ "│ │ └─AxialPositionalEmbedding: 3-5 [1, 896, 128] 8,640\n",
+ "├─TextDecoder: 1-2 [1, 58, 89] --\n",
+ "│ └─Embedding: 2-3 [1, 89, 128] 7,424\n",
+ "│ └─PositionalEncoding: 2-4 [1, 89, 128] --\n",
+ "│ │ └─Dropout: 3-6 [1, 89, 128] --\n",
+ "│ └─Decoder: 2-5 [1, 89, 128] --\n",
+ "│ │ └─ModuleList: 3-7 -- --\n",
+ "│ │ │ └─DecoderBlock: 4-4 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-5 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-6 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-7 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-8 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-9 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-10 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-11 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-12 [1, 89, 128] 525,568\n",
+ "│ │ │ └─DecoderBlock: 4-13 [1, 89, 128] 525,568\n",
+ "│ │ └─LayerNorm: 3-8 [1, 89, 128] 256\n",
+ "│ └─Linear: 2-6 [1, 89, 58] 7,482\n",
+ "==============================================================================================================\n",
+ "Total params: 6,574,714\n",
+ "Trainable params: 6,574,714\n",
"Non-trainable params: 0\n",
- "Total mult-adds (G): 8.47\n",
- "=========================================================================================================\n",
+ "Total mult-adds (G): 8.45\n",
+ "==============================================================================================================\n",
"Input size (MB): 0.23\n",
- "Forward/backward pass size (MB): 442.25\n",
- "Params size (MB): 40.78\n",
- "Estimated Total Size (MB): 483.26\n",
- "========================================================================================================="
+ "Forward/backward pass size (MB): 330.38\n",
+ "Params size (MB): 26.30\n",
+ "Estimated Total Size (MB): 356.91\n",
+ "=============================================================================================================="
]
},
- "execution_count": 16,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}