diff options
-rw-r--r-- | notebooks/04-quantizer.ipynb | 295 |
1 files changed, 0 insertions, 295 deletions
diff --git a/notebooks/04-quantizer.ipynb b/notebooks/04-quantizer.ipynb deleted file mode 100644 index 66c9c26..0000000 --- a/notebooks/04-quantizer.ipynb +++ /dev/null @@ -1,295 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "7c02ae76-b540-4b16-9492-e9210b3b9249", - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n", - "import random\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "\n", - "import numpy as np\n", - "from omegaconf import OmegaConf\n", - "\n", - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "ccdb6dde-47e5-429a-88f2-0764fb7e259a", - "metadata": {}, - "outputs": [], - "source": [ - "from hydra import compose, initialize\n", - "from omegaconf import OmegaConf\n", - "from hydra.utils import instantiate" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7", - "metadata": {}, - "outputs": [], - "source": [ - "path = \"../training/conf/network/quantizer.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "e52ecb01-c975-4e55-925d-1182c7aea473", - "metadata": {}, - "outputs": [], - "source": [ - "with open(path, \"rb\") as f:\n", - " cfg = OmegaConf.load(f)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "f939aa37-7b1d-45cc-885c-323c4540bda1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'_target_': 'text_recognizer.networks.quantizer.quantizer.VectorQuantizer', 'input_dim': 192, 'codebook': {'_target_': 'text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook', 'dim': 16, 'codebook_size': 2048, 'kmeans_init': True, 'kmeans_iters': 10, 'decay': 0.8, 'eps': 1e-05, 'threshold_dead': 2}, 'commitment': 1.0}" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "cfg" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0", - "metadata": {}, - "outputs": [], - "source": [ - "vq = instantiate(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "7a0c9f4f-3d95-4722-9212-915a4b9ed096", - "metadata": {}, - "outputs": [], - "source": [ - "import torch" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "284ec8aa-43a0-4e59-a86f-91bab6c97dca", - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randn(2, 192, 18, 20)" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "d3f6dad0-33f9-4f80-b514-dd0f71c8b93a", - "metadata": {}, - "outputs": [], - "source": [ - "(tt, ii, l) = vq(t)" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "80fd228a-d9a2-4334-ab26-c283d215a456", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 192, 18, 20])" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tt.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "22d04553-fd12-4f6c-8a43-9105083b0b82", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 192, 360])" - ] - }, - "execution_count": 40, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tt.flatten(start_dim=2).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "b79649ce-3623-4dd2-9a38-bef7bc5c9ac1", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 484, 223, 752, 735, 199, 1428, 238, 65, 1357, 950, 1792, 87,\n", - " 1006, 264, 1425, 357, 375, 131, 958, 807, 1903, 577, 552, 104,\n", - " 278, 495, 578, 1415, 1737, 593, 1442, 269, 62, 1274, 346, 667,\n", - " 1296, 421, 455, 1151, 475, 880, 1112, 279, 44, 171, 101, 638,\n", - " 719, 1502, 185, 103, 157, 802, 123, 30, 877, 691, 92, 151,\n", - " 95, 442, 1782, 819, 466, 360, 1376, 769, 636, 1212, 840, 467,\n", - " 181, 933, 1708, 706, 521, 423, 1009, 1337, 527, 382, 431, 789,\n", - " 89, 97, 138, 45, 448, 949, 767, 549, 613, 603, 222, 620,\n", - " 57, 235, 488, 938, 1163, 1105, 0, 818, 1663, 220, 1785, 132,\n", - " 270, 1103, 503, 388, 655, 102, 122, 1914, 52, 166, 361, 40,\n", - " 996, 399, 1075, 1340, 857, 256, 366, 773, 324, 110, 820, 1855,\n", - " 280, 8, 1000, 836, 1022, 1590, 1121, 1605, 1523, 168, 1135, 91,\n", - " 156, 372, 1134, 507, 1307, 344, 545, 537, 401, 611, 479, 581,\n", - " 990, 1147, 209, 1229, 211, 612, 1188, 822, 571, 221, 625, 206,\n", - " 233, 225, 1472, 81, 440, 1059, 54, 539, 1156, 74, 1902, 704,\n", - " 217, 1094, 1189, 1694, 273, 1190, 327, 1629, 947, 27, 1885, 51,\n", - " 427, 460, 1065, 83, 100, 642, 29, 795, 366, 1122, 1264, 133,\n", - " 300, 28, 1033, 1628, 41, 194, 130, 44, 75, 1525, 229, 374,\n", - " 163, 1060, 70, 1472, 619, 1109, 804, 219, 291, 1038, 1269, 408,\n", - " 251, 876, 1068, 385, 33, 345, 66, 403, 570, 1556, 1895, 230,\n", - " 1450, 498, 105, 249, 386, 711, 1577, 20, 153, 370, 91, 518,\n", - " 1097, 109, 688, 966, 1078, 385, 335, 815, 615, 118, 169, 329,\n", - " 609, 111, 1580, 785, 149, 368, 263, 43, 1261, 289, 308, 195,\n", - " 622, 184, 1342, 660, 208, 321, 432, 226, 430, 979, 347, 172,\n", - " 1508, 896, 16, 307, 564, 69, 645, 191, 125, 928, 414, 1288,\n", - " 90, 152, 218, 113, 88, 1688, 297, 705, 983, 144, 676, 187,\n", - " 1290, 48, 861, 139, 800, 964, 1584, 217, 140, 449, 500, 441,\n", - " 285, 5, 296, 325, 96, 751, 134, 137, 107, 76, 348, 302,\n", - " 150, 1367, 739, 872, 445, 24, 1209, 14, 326, 148, 36, 2046,\n", - " 1668, 1045, 53, 200, 454, 996, 10, 1069, 504, 630, 958, 1026],\n", - " [ 60, 1160, 1018, 1137, 509, 824, 182, 576, 890, 1076, 569, 769,\n", - " 79, 763, 566, 604, 1862, 286, 691, 189, 1604, 251, 771, 436,\n", - " 361, 715, 328, 1027, 287, 1697, 559, 1014, 582, 239, 299, 23,\n", - " 891, 459, 682, 1600, 520, 1112, 1898, 142, 874, 244, 99, 126,\n", - " 61, 129, 162, 331, 961, 55, 136, 179, 47, 292, 1, 164,\n", - " 453, 532, 7, 934, 42, 1325, 775, 607, 115, 2009, 174, 1187,\n", - " 654, 63, 127, 186, 426, 349, 1168, 309, 1715, 1108, 160, 433,\n", - " 525, 901, 379, 19, 28, 108, 425, 1704, 640, 59, 37, 534,\n", - " 330, 276, 523, 395, 35, 837, 537, 50, 1616, 1140, 641, 487,\n", - " 246, 327, 905, 1872, 242, 202, 684, 1275, 1098, 546, 243, 1841,\n", - " 190, 988, 17, 292, 387, 1128, 120, 652, 356, 716, 376, 1228,\n", - " 494, 32, 252, 1007, 1082, 526, 318, 303, 356, 917, 497, 816,\n", - " 1795, 220, 25, 1454, 595, 1372, 572, 478, 131, 863, 626, 26,\n", - " 492, 758, 756, 1146, 175, 1044, 846, 355, 472, 1136, 973, 64,\n", - " 11, 322, 201, 18, 444, 265, 730, 823, 176, 351, 1037, 205,\n", - " 475, 1985, 367, 1267, 1048, 918, 146, 1195, 510, 966, 372, 124,\n", - " 31, 517, 508, 94, 145, 58, 334, 1557, 93, 293, 13, 67,\n", - " 1193, 1791, 12, 993, 1710, 610, 319, 9, 683, 661, 1941, 697,\n", - " 501, 463, 1273, 401, 476, 3, 1571, 247, 590, 1198, 62, 405,\n", - " 1023, 1141, 1777, 831, 259, 71, 135, 651, 1244, 22, 777, 402,\n", - " 1040, 1303, 1421, 398, 1382, 70, 275, 549, 1707, 733, 737, 1467,\n", - " 790, 257, 621, 80, 1254, 592, 428, 272, 183, 73, 799, 583,\n", - " 506, 785, 210, 248, 896, 56, 49, 184, 21, 1058, 78, 736,\n", - " 412, 543, 175, 714, 1577, 531, 645, 660, 792, 86, 1925, 378,\n", - " 1467, 450, 98, 1074, 159, 653, 188, 1006, 320, 362, 1599, 1508,\n", - " 182, 709, 180, 227, 398, 718, 680, 337, 310, 294, 1649, 413,\n", - " 565, 546, 106, 1292, 68, 237, 290, 6, 797, 4, 701, 245,\n", - " 1320, 377, 912, 847, 670, 15, 1909, 284, 38, 558, 1473, 1375,\n", - " 241, 1473, 77, 281, 620, 360, 312, 437, 262, 416, 435, 796,\n", - " 474, 1250, 2, 1087, 170, 612, 283, 750, 369, 745, 304, 793]])" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "ii" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "618b997c-e6a6-4487-b70c-9d260cb556d3", - "metadata": {}, - "outputs": [], - "source": [ - "from torchinfo import summary" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f", - "metadata": {}, - "outputs": [], - "source": [ - "summary(net, (2, 1, 224, 224), device=\"cpu\", depth=4)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6aa04f07-12d4-4e06-b921-d54367c50a9a", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.7" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} |