summaryrefslogtreecommitdiff
path: root/notebooks/04-quantizer.ipynb
blob: 66c9c2683f22230d14bad7b88831feb036b03dfe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "7c02ae76-b540-4b16-9492-e9210b3b9249",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA. The variable name is\n",
    "# plural; the singular spelling is not read by CUDA and has no effect.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n",
    "import random\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import numpy as np\n",
    "from omegaconf import OmegaConf\n",
    "\n",
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "\n",
    "# If text_recognizer is not importable, add the parent directory to sys.path.\n",
    "from importlib.util import find_spec\n",
    "if find_spec(\"text_recognizer\") is None:\n",
    "    import sys\n",
    "    sys.path.append('..')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ccdb6dde-47e5-429a-88f2-0764fb7e259a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from hydra import compose, initialize\n",
    "from omegaconf import OmegaConf\n",
    "from hydra.utils import instantiate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7",
   "metadata": {},
   "outputs": [],
   "source": [
    "path = \"../training/conf/network/quantizer.yaml\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e52ecb01-c975-4e55-925d-1182c7aea473",
   "metadata": {},
   "outputs": [],
   "source": [
    "# OmegaConf.load accepts a file path and handles opening/closing the file itself.\n",
    "cfg = OmegaConf.load(path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f939aa37-7b1d-45cc-885c-323c4540bda1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'_target_': 'text_recognizer.networks.quantizer.quantizer.VectorQuantizer', 'input_dim': 192, 'codebook': {'_target_': 'text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook', 'dim': 16, 'codebook_size': 2048, 'kmeans_init': True, 'kmeans_iters': 10, 'decay': 0.8, 'eps': 1e-05, 'threshold_dead': 2}, 'commitment': 1.0}"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "cfg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "vq = instantiate(cfg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "7a0c9f4f-3d95-4722-9212-915a4b9ed096",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "284ec8aa-43a0-4e59-a86f-91bab6c97dca",
   "metadata": {},
   "outputs": [],
   "source": [
    "t = torch.randn(2, 192, 18, 20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "d3f6dad0-33f9-4f80-b514-dd0f71c8b93a",
   "metadata": {},
   "outputs": [],
   "source": [
    "(tt, ii, l) = vq(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "80fd228a-d9a2-4334-ab26-c283d215a456",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 192, 18, 20])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "22d04553-fd12-4f6c-8a43-9105083b0b82",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([2, 192, 360])"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tt.flatten(start_dim=2).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "b79649ce-3623-4dd2-9a38-bef7bc5c9ac1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 484,  223,  752,  735,  199, 1428,  238,   65, 1357,  950, 1792,   87,\n",
       "         1006,  264, 1425,  357,  375,  131,  958,  807, 1903,  577,  552,  104,\n",
       "          278,  495,  578, 1415, 1737,  593, 1442,  269,   62, 1274,  346,  667,\n",
       "         1296,  421,  455, 1151,  475,  880, 1112,  279,   44,  171,  101,  638,\n",
       "          719, 1502,  185,  103,  157,  802,  123,   30,  877,  691,   92,  151,\n",
       "           95,  442, 1782,  819,  466,  360, 1376,  769,  636, 1212,  840,  467,\n",
       "          181,  933, 1708,  706,  521,  423, 1009, 1337,  527,  382,  431,  789,\n",
       "           89,   97,  138,   45,  448,  949,  767,  549,  613,  603,  222,  620,\n",
       "           57,  235,  488,  938, 1163, 1105,    0,  818, 1663,  220, 1785,  132,\n",
       "          270, 1103,  503,  388,  655,  102,  122, 1914,   52,  166,  361,   40,\n",
       "          996,  399, 1075, 1340,  857,  256,  366,  773,  324,  110,  820, 1855,\n",
       "          280,    8, 1000,  836, 1022, 1590, 1121, 1605, 1523,  168, 1135,   91,\n",
       "          156,  372, 1134,  507, 1307,  344,  545,  537,  401,  611,  479,  581,\n",
       "          990, 1147,  209, 1229,  211,  612, 1188,  822,  571,  221,  625,  206,\n",
       "          233,  225, 1472,   81,  440, 1059,   54,  539, 1156,   74, 1902,  704,\n",
       "          217, 1094, 1189, 1694,  273, 1190,  327, 1629,  947,   27, 1885,   51,\n",
       "          427,  460, 1065,   83,  100,  642,   29,  795,  366, 1122, 1264,  133,\n",
       "          300,   28, 1033, 1628,   41,  194,  130,   44,   75, 1525,  229,  374,\n",
       "          163, 1060,   70, 1472,  619, 1109,  804,  219,  291, 1038, 1269,  408,\n",
       "          251,  876, 1068,  385,   33,  345,   66,  403,  570, 1556, 1895,  230,\n",
       "         1450,  498,  105,  249,  386,  711, 1577,   20,  153,  370,   91,  518,\n",
       "         1097,  109,  688,  966, 1078,  385,  335,  815,  615,  118,  169,  329,\n",
       "          609,  111, 1580,  785,  149,  368,  263,   43, 1261,  289,  308,  195,\n",
       "          622,  184, 1342,  660,  208,  321,  432,  226,  430,  979,  347,  172,\n",
       "         1508,  896,   16,  307,  564,   69,  645,  191,  125,  928,  414, 1288,\n",
       "           90,  152,  218,  113,   88, 1688,  297,  705,  983,  144,  676,  187,\n",
       "         1290,   48,  861,  139,  800,  964, 1584,  217,  140,  449,  500,  441,\n",
       "          285,    5,  296,  325,   96,  751,  134,  137,  107,   76,  348,  302,\n",
       "          150, 1367,  739,  872,  445,   24, 1209,   14,  326,  148,   36, 2046,\n",
       "         1668, 1045,   53,  200,  454,  996,   10, 1069,  504,  630,  958, 1026],\n",
       "        [  60, 1160, 1018, 1137,  509,  824,  182,  576,  890, 1076,  569,  769,\n",
       "           79,  763,  566,  604, 1862,  286,  691,  189, 1604,  251,  771,  436,\n",
       "          361,  715,  328, 1027,  287, 1697,  559, 1014,  582,  239,  299,   23,\n",
       "          891,  459,  682, 1600,  520, 1112, 1898,  142,  874,  244,   99,  126,\n",
       "           61,  129,  162,  331,  961,   55,  136,  179,   47,  292,    1,  164,\n",
       "          453,  532,    7,  934,   42, 1325,  775,  607,  115, 2009,  174, 1187,\n",
       "          654,   63,  127,  186,  426,  349, 1168,  309, 1715, 1108,  160,  433,\n",
       "          525,  901,  379,   19,   28,  108,  425, 1704,  640,   59,   37,  534,\n",
       "          330,  276,  523,  395,   35,  837,  537,   50, 1616, 1140,  641,  487,\n",
       "          246,  327,  905, 1872,  242,  202,  684, 1275, 1098,  546,  243, 1841,\n",
       "          190,  988,   17,  292,  387, 1128,  120,  652,  356,  716,  376, 1228,\n",
       "          494,   32,  252, 1007, 1082,  526,  318,  303,  356,  917,  497,  816,\n",
       "         1795,  220,   25, 1454,  595, 1372,  572,  478,  131,  863,  626,   26,\n",
       "          492,  758,  756, 1146,  175, 1044,  846,  355,  472, 1136,  973,   64,\n",
       "           11,  322,  201,   18,  444,  265,  730,  823,  176,  351, 1037,  205,\n",
       "          475, 1985,  367, 1267, 1048,  918,  146, 1195,  510,  966,  372,  124,\n",
       "           31,  517,  508,   94,  145,   58,  334, 1557,   93,  293,   13,   67,\n",
       "         1193, 1791,   12,  993, 1710,  610,  319,    9,  683,  661, 1941,  697,\n",
       "          501,  463, 1273,  401,  476,    3, 1571,  247,  590, 1198,   62,  405,\n",
       "         1023, 1141, 1777,  831,  259,   71,  135,  651, 1244,   22,  777,  402,\n",
       "         1040, 1303, 1421,  398, 1382,   70,  275,  549, 1707,  733,  737, 1467,\n",
       "          790,  257,  621,   80, 1254,  592,  428,  272,  183,   73,  799,  583,\n",
       "          506,  785,  210,  248,  896,   56,   49,  184,   21, 1058,   78,  736,\n",
       "          412,  543,  175,  714, 1577,  531,  645,  660,  792,   86, 1925,  378,\n",
       "         1467,  450,   98, 1074,  159,  653,  188, 1006,  320,  362, 1599, 1508,\n",
       "          182,  709,  180,  227,  398,  718,  680,  337,  310,  294, 1649,  413,\n",
       "          565,  546,  106, 1292,   68,  237,  290,    6,  797,    4,  701,  245,\n",
       "         1320,  377,  912,  847,  670,   15, 1909,  284,   38,  558, 1473, 1375,\n",
       "          241, 1473,   77,  281,  620,  360,  312,  437,  262,  416,  435,  796,\n",
       "          474, 1250,    2, 1087,  170,  612,  283,  750,  369,  745,  304,  793]])"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ii"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "618b997c-e6a6-4487-b70c-9d260cb556d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchinfo import summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Summarise the quantizer `vq`, using the same input shape as the tensor `t`\n",
    "# (batch=2, channels=192 to match the configured input_dim, 18x20 spatial grid).\n",
    "summary(vq, (2, 192, 18, 20), device=\"cpu\", depth=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6aa04f07-12d4-4e06-b921-d54367c50a9a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}