diff options
Diffstat (limited to 'notebooks/04-quantizer.ipynb')
-rw-r--r-- | notebooks/04-quantizer.ipynb | 295 |
1 files changed, 295 insertions, 0 deletions
diff --git a/notebooks/04-quantizer.ipynb b/notebooks/04-quantizer.ipynb new file mode 100644 index 0000000..66c9c26 --- /dev/null +++ b/notebooks/04-quantizer.ipynb @@ -0,0 +1,295 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "7c02ae76-b540-4b16-9492-e9210b3b9249", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['CUDA_VISIBLE_DEVICES'] = ''  # hide all GPUs (must be set before CUDA is initialised); note the plural name\n", + "import random\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import numpy as np\n", + "from omegaconf import OmegaConf\n", + "\n", + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "ccdb6dde-47e5-429a-88f2-0764fb7e259a", + "metadata": {}, + "outputs": [], + "source": [ + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf\n", + "from hydra.utils import instantiate" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7", + "metadata": {}, + "outputs": [], + "source": [ + "path = \"../training/conf/network/quantizer.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e52ecb01-c975-4e55-925d-1182c7aea473", + "metadata": {}, + "outputs": [], + "source": [ + "with open(path, \"rb\") as f:\n", + " cfg = OmegaConf.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f939aa37-7b1d-45cc-885c-323c4540bda1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'_target_': 'text_recognizer.networks.quantizer.quantizer.VectorQuantizer', 'input_dim': 192, 'codebook': {'_target_': 'text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook', 'dim': 16, 'codebook_size': 2048, 'kmeans_init': True, 'kmeans_iters': 10, 'decay': 0.8, 'eps': 1e-05, 'threshold_dead': 2}, 
'commitment': 1.0}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0", + "metadata": {}, + "outputs": [], + "source": [ + "vq = instantiate(cfg)  # build the object named by cfg._target_" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "7a0c9f4f-3d95-4722-9212-915a4b9ed096", + "metadata": {}, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "284ec8aa-43a0-4e59-a86f-91bab6c97dca", + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randn((2, 192, 18, 20))  # random input; 192 channels matches cfg.input_dim" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "d3f6dad0-33f9-4f80-b514-dd0f71c8b93a", + "metadata": {}, + "outputs": [], + "source": [ + "tt, ii, l = vq(t)  # the forward pass returns a 3-tuple" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "80fd228a-d9a2-4334-ab26-c283d215a456", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 192, 18, 20])" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tt.size()" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "22d04553-fd12-4f6c-8a43-9105083b0b82", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 192, 360])" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tt.flatten(2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "b79649ce-3623-4dd2-9a38-bef7bc5c9ac1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[ 484, 223, 752, 735, 199, 1428, 238, 65, 1357, 950, 1792, 87,\n", + " 1006, 264, 1425, 357, 375, 131, 958, 807, 1903, 577, 552, 104,\n", + " 278, 495, 578, 1415, 1737, 593, 1442, 269, 62, 1274, 346, 667,\n", + " 1296, 421, 455, 1151, 475, 880, 1112, 279, 44, 
171, 101, 638,\n", + " 719, 1502, 185, 103, 157, 802, 123, 30, 877, 691, 92, 151,\n", + " 95, 442, 1782, 819, 466, 360, 1376, 769, 636, 1212, 840, 467,\n", + " 181, 933, 1708, 706, 521, 423, 1009, 1337, 527, 382, 431, 789,\n", + " 89, 97, 138, 45, 448, 949, 767, 549, 613, 603, 222, 620,\n", + " 57, 235, 488, 938, 1163, 1105, 0, 818, 1663, 220, 1785, 132,\n", + " 270, 1103, 503, 388, 655, 102, 122, 1914, 52, 166, 361, 40,\n", + " 996, 399, 1075, 1340, 857, 256, 366, 773, 324, 110, 820, 1855,\n", + " 280, 8, 1000, 836, 1022, 1590, 1121, 1605, 1523, 168, 1135, 91,\n", + " 156, 372, 1134, 507, 1307, 344, 545, 537, 401, 611, 479, 581,\n", + " 990, 1147, 209, 1229, 211, 612, 1188, 822, 571, 221, 625, 206,\n", + " 233, 225, 1472, 81, 440, 1059, 54, 539, 1156, 74, 1902, 704,\n", + " 217, 1094, 1189, 1694, 273, 1190, 327, 1629, 947, 27, 1885, 51,\n", + " 427, 460, 1065, 83, 100, 642, 29, 795, 366, 1122, 1264, 133,\n", + " 300, 28, 1033, 1628, 41, 194, 130, 44, 75, 1525, 229, 374,\n", + " 163, 1060, 70, 1472, 619, 1109, 804, 219, 291, 1038, 1269, 408,\n", + " 251, 876, 1068, 385, 33, 345, 66, 403, 570, 1556, 1895, 230,\n", + " 1450, 498, 105, 249, 386, 711, 1577, 20, 153, 370, 91, 518,\n", + " 1097, 109, 688, 966, 1078, 385, 335, 815, 615, 118, 169, 329,\n", + " 609, 111, 1580, 785, 149, 368, 263, 43, 1261, 289, 308, 195,\n", + " 622, 184, 1342, 660, 208, 321, 432, 226, 430, 979, 347, 172,\n", + " 1508, 896, 16, 307, 564, 69, 645, 191, 125, 928, 414, 1288,\n", + " 90, 152, 218, 113, 88, 1688, 297, 705, 983, 144, 676, 187,\n", + " 1290, 48, 861, 139, 800, 964, 1584, 217, 140, 449, 500, 441,\n", + " 285, 5, 296, 325, 96, 751, 134, 137, 107, 76, 348, 302,\n", + " 150, 1367, 739, 872, 445, 24, 1209, 14, 326, 148, 36, 2046,\n", + " 1668, 1045, 53, 200, 454, 996, 10, 1069, 504, 630, 958, 1026],\n", + " [ 60, 1160, 1018, 1137, 509, 824, 182, 576, 890, 1076, 569, 769,\n", + " 79, 763, 566, 604, 1862, 286, 691, 189, 1604, 251, 771, 436,\n", + " 361, 715, 328, 1027, 287, 1697, 559, 
1014, 582, 239, 299, 23,\n", + " 891, 459, 682, 1600, 520, 1112, 1898, 142, 874, 244, 99, 126,\n", + " 61, 129, 162, 331, 961, 55, 136, 179, 47, 292, 1, 164,\n", + " 453, 532, 7, 934, 42, 1325, 775, 607, 115, 2009, 174, 1187,\n", + " 654, 63, 127, 186, 426, 349, 1168, 309, 1715, 1108, 160, 433,\n", + " 525, 901, 379, 19, 28, 108, 425, 1704, 640, 59, 37, 534,\n", + " 330, 276, 523, 395, 35, 837, 537, 50, 1616, 1140, 641, 487,\n", + " 246, 327, 905, 1872, 242, 202, 684, 1275, 1098, 546, 243, 1841,\n", + " 190, 988, 17, 292, 387, 1128, 120, 652, 356, 716, 376, 1228,\n", + " 494, 32, 252, 1007, 1082, 526, 318, 303, 356, 917, 497, 816,\n", + " 1795, 220, 25, 1454, 595, 1372, 572, 478, 131, 863, 626, 26,\n", + " 492, 758, 756, 1146, 175, 1044, 846, 355, 472, 1136, 973, 64,\n", + " 11, 322, 201, 18, 444, 265, 730, 823, 176, 351, 1037, 205,\n", + " 475, 1985, 367, 1267, 1048, 918, 146, 1195, 510, 966, 372, 124,\n", + " 31, 517, 508, 94, 145, 58, 334, 1557, 93, 293, 13, 67,\n", + " 1193, 1791, 12, 993, 1710, 610, 319, 9, 683, 661, 1941, 697,\n", + " 501, 463, 1273, 401, 476, 3, 1571, 247, 590, 1198, 62, 405,\n", + " 1023, 1141, 1777, 831, 259, 71, 135, 651, 1244, 22, 777, 402,\n", + " 1040, 1303, 1421, 398, 1382, 70, 275, 549, 1707, 733, 737, 1467,\n", + " 790, 257, 621, 80, 1254, 592, 428, 272, 183, 73, 799, 583,\n", + " 506, 785, 210, 248, 896, 56, 49, 184, 21, 1058, 78, 736,\n", + " 412, 543, 175, 714, 1577, 531, 645, 660, 792, 86, 1925, 378,\n", + " 1467, 450, 98, 1074, 159, 653, 188, 1006, 320, 362, 1599, 1508,\n", + " 182, 709, 180, 227, 398, 718, 680, 337, 310, 294, 1649, 413,\n", + " 565, 546, 106, 1292, 68, 237, 290, 6, 797, 4, 701, 245,\n", + " 1320, 377, 912, 847, 670, 15, 1909, 284, 38, 558, 1473, 1375,\n", + " 241, 1473, 77, 281, 620, 360, 312, 437, 262, 416, 435, 796,\n", + " 474, 1250, 2, 1087, 170, 612, 283, 750, 369, 745, 304, 793]])" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ii" + ] + }, 
+ { + "cell_type": "code", + "execution_count": null, + "id": "618b997c-e6a6-4487-b70c-9d260cb556d3", + "metadata": {}, + "outputs": [], + "source": [ + "from torchinfo import summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f", + "metadata": {}, + "outputs": [], + "source": [ + "summary(vq, (2, 192, 18, 20), device=\"cpu\", depth=4)  # summarise the quantizer built above, using the same input shape as t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6aa04f07-12d4-4e06-b921-d54367c50a9a", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.7" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} |