summaryrefslogtreecommitdiff
path: root/notebooks/04-efficientnet-transformer.ipynb
blob: 8affa9d4bf093ff80919fe9c2d2d8e8949e2bdd7 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7c02ae76-b540-4b16-9492-e9210b3b9249",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Hide every GPU from CUDA so the network below is built and summarised on CPU.\n",
    "# The variable is CUDA_VISIBLE_DEVICES (plural) and must be set before torch initialises CUDA.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# Fall back to the repository root when text_recognizer is not installed as a package.\n",
    "from importlib.util import find_spec\n",
    "if find_spec(\"text_recognizer\") is None:\n",
    "    import sys\n",
    "    sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ccdb6dde-47e5-429a-88f2-0764fb7e259a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hydra helpers for composing configs and instantiating objects from `_target_` entries.\n",
    "# (OmegaConf is already imported in the first cell.)\n",
    "from hydra import compose, initialize\n",
    "from hydra.utils import instantiate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment config bundling network, datamodule, optimizer and trainer settings.\n",
    "path = '../training/conf/experiment/cnn_htr_char_lines.yaml'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e52ecb01-c975-4e55-925d-1182c7aea473",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OmegaConf.load accepts a filesystem path directly and opens/closes the file itself.\n",
    "cfg = OmegaConf.load(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "f939aa37-7b1d-45cc-885c-323c4540bda1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'defaults': [{'override /mapping': None}, {'override /criterion': None}, {'override /datamodule': None}, {'override /network': None}, {'override /model': None}, {'override /lr_schedulers': None}, {'override /optimizers': None}], 'ignore_index': 3, 'num_classes': 58, 'max_output_len': 89, 'criterion': {'_target_': 'torch.nn.CrossEntropyLoss', 'ignore_index': 3}, 'mapping': {'_target_': 'text_recognizer.data.emnist_mapping.EmnistMapping'}, 'callbacks': {'stochastic_weight_averaging': {'_target_': 'pytorch_lightning.callbacks.StochasticWeightAveraging', 'swa_epoch_start': 0.8, 'swa_lrs': 0.05, 'annealing_epochs': 10, 'annealing_strategy': 'cos', 'device': None}}, 'optimizers': {'madgrad': {'_target_': 'madgrad.MADGRAD', 'lr': 0.0003, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06, 'parameters': 'network'}}, 'lr_schedulers': {'network': {'_target_': 'torch.optim.lr_scheduler.ReduceLROnPlateau', 'mode': 'min', 'factor': 0.5, 'patience': 10, 'threshold': 0.0001, 'threshold_mode': 'rel', 'cooldown': 0, 'min_lr': 1e-06, 'eps': 1e-08, 'interval': 'epoch', 'monitor': 'val/loss'}}, 'datamodule': {'_target_': 'text_recognizer.data.iam_lines.IAMLines', 'batch_size': 16, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': True, 'word_pieces': False}, 'network': {'_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 56, 1024], 'hidden_dim': 128, 'encoder_dim': 1280, 'dropout_rate': 0.2, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 128, 'depth': 3, 'num_heads': 4, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'dim_head': 32, 'dropout_rate': 0.2}, 'norm_fn': 'text_recognizer.networks.transformer.norm.ScaleNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'cross_attend': True, 'pre_norm': True, 'rotary_emb': {'_target_': 'text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding', 'dim': 32}}}, 'model': {'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'max_output_len': 89, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': True, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0.5, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 420, 'terminate_on_nan': True, 'weights_summary': None, 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None, 'accumulate_grad_batches': 8, 'overfit_batches': 0}, 'summary': [[1, 1, 56, 1024], [1, 89]]}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inspect the loaded experiment config.\n",
    "cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the network declared under cfg.network; Hydra resolves the nested `_target_` entries.\n",
    "net = instantiate(cfg.network)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "618b997c-e6a6-4487-b70c-9d260cb556d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# torchinfo reports per-layer output shapes and parameter counts.\n",
    "from torchinfo import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "====================================================================================================\n",
       "Layer (type:depth-idx)                             Output Shape              Param #\n",
       "====================================================================================================\n",
       "ConvTransformer                                    --                        --\n",
       "├─EfficientNet: 1-1                                [1, 1280, 1, 32]          7,142,272\n",
       "├─Sequential: 1-2                                  [1, 128, 32]              163,968\n",
       "├─Embedding: 1-3                                   [1, 89, 128]              7,424\n",
       "├─PositionalEncoding: 1-4                          [1, 89, 128]              --\n",
       "├─Decoder: 1-5                                     [1, 89, 128]              13,176,969\n",
       "├─Linear: 1-6                                      [1, 89, 58]               7,482\n",
       "====================================================================================================\n",
       "Total params: 20,498,115\n",
       "Trainable params: 20,498,115\n",
       "Non-trainable params: 0\n",
       "Total mult-adds (M): 714.86\n",
       "====================================================================================================\n",
       "Input size (MB): 0.23\n",
       "Forward/backward pass size (MB): 184.29\n",
       "Params size (MB): 81.99\n",
       "Estimated Total Size (MB): 266.51\n",
       "===================================================================================================="
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# cfg.summary lists the dummy input shapes to trace, presumably (image batch, target token batch);\n",
    "# each entry is an OmegaConf ListConfig, converted to a plain list before handing it to torchinfo.\n",
    "# NOTE(review): the forward pass prints raw tensors to stdout (looks like a leftover debug print\n",
    "# inside text_recognizer) - TODO remove it there; that noise was cleared from this cell's output.\n",
    "summary(net, [list(shape) for shape in cfg.summary], device=\"cpu\", depth=1)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}