summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--notebooks/04-quantizer.ipynb295
1 files changed, 295 insertions, 0 deletions
diff --git a/notebooks/04-quantizer.ipynb b/notebooks/04-quantizer.ipynb
new file mode 100644
index 0000000..66c9c26
--- /dev/null
+++ b/notebooks/04-quantizer.ipynb
@@ -0,0 +1,295 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "7c02ae76-b540-4b16-9492-e9210b3b9249",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n",
+ "import random\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import numpy as np\n",
+ "from omegaconf import OmegaConf\n",
+ "\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "ccdb6dde-47e5-429a-88f2-0764fb7e259a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from hydra import compose, initialize\n",
+ "from omegaconf import OmegaConf\n",
+ "from hydra.utils import instantiate"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path = \"../training/conf/network/quantizer.yaml\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "e52ecb01-c975-4e55-925d-1182c7aea473",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with open(path, \"rb\") as f:\n",
+ " cfg = OmegaConf.load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "f939aa37-7b1d-45cc-885c-323c4540bda1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'_target_': 'text_recognizer.networks.quantizer.quantizer.VectorQuantizer', 'input_dim': 192, 'codebook': {'_target_': 'text_recognizer.networks.quantizer.codebook.CosineSimilarityCodebook', 'dim': 16, 'codebook_size': 2048, 'kmeans_init': True, 'kmeans_iters': 10, 'decay': 0.8, 'eps': 1e-05, 'threshold_dead': 2}, 'commitment': 1.0}"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "cfg"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vq = instantiate(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "7a0c9f4f-3d95-4722-9212-915a4b9ed096",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "284ec8aa-43a0-4e59-a86f-91bab6c97dca",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(2, 192, 18, 20)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "d3f6dad0-33f9-4f80-b514-dd0f71c8b93a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "(tt, ii, l) = vq(t)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "80fd228a-d9a2-4334-ab26-c283d215a456",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 192, 18, 20])"
+ ]
+ },
+ "execution_count": 37,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tt.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "id": "22d04553-fd12-4f6c-8a43-9105083b0b82",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 192, 360])"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tt.flatten(start_dim=2).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "b79649ce-3623-4dd2-9a38-bef7bc5c9ac1",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[ 484, 223, 752, 735, 199, 1428, 238, 65, 1357, 950, 1792, 87,\n",
+ " 1006, 264, 1425, 357, 375, 131, 958, 807, 1903, 577, 552, 104,\n",
+ " 278, 495, 578, 1415, 1737, 593, 1442, 269, 62, 1274, 346, 667,\n",
+ " 1296, 421, 455, 1151, 475, 880, 1112, 279, 44, 171, 101, 638,\n",
+ " 719, 1502, 185, 103, 157, 802, 123, 30, 877, 691, 92, 151,\n",
+ " 95, 442, 1782, 819, 466, 360, 1376, 769, 636, 1212, 840, 467,\n",
+ " 181, 933, 1708, 706, 521, 423, 1009, 1337, 527, 382, 431, 789,\n",
+ " 89, 97, 138, 45, 448, 949, 767, 549, 613, 603, 222, 620,\n",
+ " 57, 235, 488, 938, 1163, 1105, 0, 818, 1663, 220, 1785, 132,\n",
+ " 270, 1103, 503, 388, 655, 102, 122, 1914, 52, 166, 361, 40,\n",
+ " 996, 399, 1075, 1340, 857, 256, 366, 773, 324, 110, 820, 1855,\n",
+ " 280, 8, 1000, 836, 1022, 1590, 1121, 1605, 1523, 168, 1135, 91,\n",
+ " 156, 372, 1134, 507, 1307, 344, 545, 537, 401, 611, 479, 581,\n",
+ " 990, 1147, 209, 1229, 211, 612, 1188, 822, 571, 221, 625, 206,\n",
+ " 233, 225, 1472, 81, 440, 1059, 54, 539, 1156, 74, 1902, 704,\n",
+ " 217, 1094, 1189, 1694, 273, 1190, 327, 1629, 947, 27, 1885, 51,\n",
+ " 427, 460, 1065, 83, 100, 642, 29, 795, 366, 1122, 1264, 133,\n",
+ " 300, 28, 1033, 1628, 41, 194, 130, 44, 75, 1525, 229, 374,\n",
+ " 163, 1060, 70, 1472, 619, 1109, 804, 219, 291, 1038, 1269, 408,\n",
+ " 251, 876, 1068, 385, 33, 345, 66, 403, 570, 1556, 1895, 230,\n",
+ " 1450, 498, 105, 249, 386, 711, 1577, 20, 153, 370, 91, 518,\n",
+ " 1097, 109, 688, 966, 1078, 385, 335, 815, 615, 118, 169, 329,\n",
+ " 609, 111, 1580, 785, 149, 368, 263, 43, 1261, 289, 308, 195,\n",
+ " 622, 184, 1342, 660, 208, 321, 432, 226, 430, 979, 347, 172,\n",
+ " 1508, 896, 16, 307, 564, 69, 645, 191, 125, 928, 414, 1288,\n",
+ " 90, 152, 218, 113, 88, 1688, 297, 705, 983, 144, 676, 187,\n",
+ " 1290, 48, 861, 139, 800, 964, 1584, 217, 140, 449, 500, 441,\n",
+ " 285, 5, 296, 325, 96, 751, 134, 137, 107, 76, 348, 302,\n",
+ " 150, 1367, 739, 872, 445, 24, 1209, 14, 326, 148, 36, 2046,\n",
+ " 1668, 1045, 53, 200, 454, 996, 10, 1069, 504, 630, 958, 1026],\n",
+ " [ 60, 1160, 1018, 1137, 509, 824, 182, 576, 890, 1076, 569, 769,\n",
+ " 79, 763, 566, 604, 1862, 286, 691, 189, 1604, 251, 771, 436,\n",
+ " 361, 715, 328, 1027, 287, 1697, 559, 1014, 582, 239, 299, 23,\n",
+ " 891, 459, 682, 1600, 520, 1112, 1898, 142, 874, 244, 99, 126,\n",
+ " 61, 129, 162, 331, 961, 55, 136, 179, 47, 292, 1, 164,\n",
+ " 453, 532, 7, 934, 42, 1325, 775, 607, 115, 2009, 174, 1187,\n",
+ " 654, 63, 127, 186, 426, 349, 1168, 309, 1715, 1108, 160, 433,\n",
+ " 525, 901, 379, 19, 28, 108, 425, 1704, 640, 59, 37, 534,\n",
+ " 330, 276, 523, 395, 35, 837, 537, 50, 1616, 1140, 641, 487,\n",
+ " 246, 327, 905, 1872, 242, 202, 684, 1275, 1098, 546, 243, 1841,\n",
+ " 190, 988, 17, 292, 387, 1128, 120, 652, 356, 716, 376, 1228,\n",
+ " 494, 32, 252, 1007, 1082, 526, 318, 303, 356, 917, 497, 816,\n",
+ " 1795, 220, 25, 1454, 595, 1372, 572, 478, 131, 863, 626, 26,\n",
+ " 492, 758, 756, 1146, 175, 1044, 846, 355, 472, 1136, 973, 64,\n",
+ " 11, 322, 201, 18, 444, 265, 730, 823, 176, 351, 1037, 205,\n",
+ " 475, 1985, 367, 1267, 1048, 918, 146, 1195, 510, 966, 372, 124,\n",
+ " 31, 517, 508, 94, 145, 58, 334, 1557, 93, 293, 13, 67,\n",
+ " 1193, 1791, 12, 993, 1710, 610, 319, 9, 683, 661, 1941, 697,\n",
+ " 501, 463, 1273, 401, 476, 3, 1571, 247, 590, 1198, 62, 405,\n",
+ " 1023, 1141, 1777, 831, 259, 71, 135, 651, 1244, 22, 777, 402,\n",
+ " 1040, 1303, 1421, 398, 1382, 70, 275, 549, 1707, 733, 737, 1467,\n",
+ " 790, 257, 621, 80, 1254, 592, 428, 272, 183, 73, 799, 583,\n",
+ " 506, 785, 210, 248, 896, 56, 49, 184, 21, 1058, 78, 736,\n",
+ " 412, 543, 175, 714, 1577, 531, 645, 660, 792, 86, 1925, 378,\n",
+ " 1467, 450, 98, 1074, 159, 653, 188, 1006, 320, 362, 1599, 1508,\n",
+ " 182, 709, 180, 227, 398, 718, 680, 337, 310, 294, 1649, 413,\n",
+ " 565, 546, 106, 1292, 68, 237, 290, 6, 797, 4, 701, 245,\n",
+ " 1320, 377, 912, 847, 670, 15, 1909, 284, 38, 558, 1473, 1375,\n",
+ " 241, 1473, 77, 281, 620, 360, 312, 437, 262, 416, 435, 796,\n",
+ " 474, 1250, 2, 1087, 170, 612, 283, 750, 369, 745, 304, 793]])"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ii"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "618b997c-e6a6-4487-b70c-9d260cb556d3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchinfo import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "25759b7b-8deb-4163-b75d-a1357c9fe88f",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(net, (2, 1, 224, 224), device=\"cpu\", depth=4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6aa04f07-12d4-4e06-b921-d54367c50a9a",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.7"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}