From a2a3133ed5da283888efbdb9924d0e3733c274c8 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 9 May 2021 18:50:55 +0200 Subject: tranformer layer done --- notebooks/00-scratch-pad.ipynb | 246 +++++++++++++++++++++++++++++------------ 1 file changed, 175 insertions(+), 71 deletions(-) (limited to 'notebooks/00-scratch-pad.ipynb') diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index b6ec2c8..0a5e2f3 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -55,6 +55,181 @@ ")" ] }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.encoders.efficientnet import EfficientNet" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "en = EfficientNet()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "==========================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "==========================================================================================\n", + "├─Sequential: 1-1 [-1, 256, 18, 20] --\n", + "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n", + "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n", + "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n", + "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n", + "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n", + "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n", + "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n", + "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n", + "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n", + "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n", + "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n", + "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n", + "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n", + "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n", + "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n", + "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n", + "==========================================================================================\n", + "Total params: 13,704,252\n", + "Trainable params: 13,704,252\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 1.23\n", + "==========================================================================================\n", + "Input size (MB): 1.41\n", + "Forward/backward pass size (MB): 111.45\n", + "Params size (MB): 52.28\n", + "Estimated Total Size (MB): 165.13\n", + "==========================================================================================\n" + ] + }, + { + "data": { + "text/plain": [ + "==========================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "==========================================================================================\n", + "├─Sequential: 1-1 [-1, 256, 18, 20] --\n", + "| └─ConvNorm: 2-1 [-1, 32, 288, 320] --\n", + "| | └─Sequential: 3-1 [-1, 32, 288, 320] 352\n", + "| └─InvertedResidulaBlock: 2-2 [-1, 16, 288, 320] --\n", + "| | └─Sequential: 3-2 [-1, 16, 288, 320] 1,448\n", + "| └─InvertedResidulaBlock: 2-3 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-3 [-1, 96, 288, 320] 14,016\n", + "| | └─Sequential: 3-4 [-1, 24, 144, 160] 4,276\n", + "| └─InvertedResidulaBlock: 2-4 [-1, 24, 144, 160] --\n", + "| | └─ConvNorm: 3-5 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-6 [-1, 24, 144, 160] 6,966\n", + "| └─InvertedResidulaBlock: 2-5 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-7 [-1, 144, 144, 160] 31,392\n", + "| | └─Sequential: 3-8 [-1, 40, 72, 80] 11,606\n", + "| └─InvertedResidulaBlock: 2-6 [-1, 40, 72, 80] --\n", + "| | └─ConvNorm: 3-9 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-10 [-1, 40, 72, 80] 21,210\n", + "| └─InvertedResidulaBlock: 2-7 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-11 [-1, 240, 72, 80] 86,880\n", + "| | └─Sequential: 3-12 [-1, 80, 36, 40] 27,050\n", + "| └─InvertedResidulaBlock: 2-8 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-13 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-14 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-9 [-1, 80, 36, 40] --\n", + "| | └─ConvNorm: 3-15 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-16 [-1, 80, 36, 40] 63,540\n", + "| └─InvertedResidulaBlock: 2-10 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-17 [-1, 480, 36, 40] 346,560\n", + "| | └─Sequential: 3-18 [-1, 112, 36, 40] 86,644\n", + "| └─InvertedResidulaBlock: 2-11 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-19 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-20 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-12 [-1, 112, 36, 40] --\n", + "| | └─ConvNorm: 3-21 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-22 [-1, 112, 36, 40] 131,964\n", + "| └─InvertedResidulaBlock: 2-13 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-23 [-1, 672, 36, 40] 678,720\n", + "| | └─Sequential: 3-24 [-1, 192, 18, 20] 185,884\n", + "| └─InvertedResidulaBlock: 2-14 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-25 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-26 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-15 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-27 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-28 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-16 [-1, 192, 18, 20] --\n", + "| | └─ConvNorm: 3-29 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-30 [-1, 192, 18, 20] 364,464\n", + "| └─InvertedResidulaBlock: 2-17 [-1, 320, 18, 20] --\n", + "| | └─ConvNorm: 3-31 [-1, 1152, 18, 20] 1,992,960\n", + "| | └─Sequential: 3-32 [-1, 320, 18, 20] 493,744\n", + "| └─ConvNorm: 2-18 [-1, 256, 18, 20] --\n", + "| | └─Sequential: 3-33 [-1, 256, 18, 20] 82,432\n", + "==========================================================================================\n", + "Total params: 13,704,252\n", + "Trainable params: 13,704,252\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 1.23\n", + "==========================================================================================\n", + "Input size (MB): 1.41\n", + "Forward/backward pass size (MB): 111.45\n", + "Params size (MB): 52.28\n", + "Estimated Total Size (MB): 165.13\n", + "==========================================================================================" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary(en, (1, 576, 640))" + ] + }, { "cell_type": "code", "execution_count": 28, @@ -407,77 +582,6 @@ "efficient_transformer()" ] }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "8" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(list(filter(lambda x: x == \"a\", (\"a\", \"c\") * 8)))" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "ModuleList(\n", - " (0): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (1): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (2): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (3): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (4): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (5): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (6): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (7): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (8): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - " (9): ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])" - ] - }, { "cell_type": "code", "execution_count": 2, -- cgit v1.2.3-70-g09d2