diff options
Diffstat (limited to 'src')
26 files changed, 418 insertions, 920 deletions
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb index 9d265ba..0294394 100644 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ b/src/notebooks/00-testing-stuff-out.ipynb @@ -22,36 +22,94 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, Encoder, ResidualNetwork" + "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "IdentityBlock(\n", + " (blocks): Identity()\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Identity()\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "IdentityBlock(32, 64)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "ResidualBlock(\n", + " (blocks): Identity()\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "ResidualBlock(32, 64)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BasicBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(32, 64, kernel_size=(3, 
3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")\n" + ] + } + ], "source": [ "dummy = torch.ones((1, 32, 224, 224))\n", "\n", @@ -62,9 +120,39 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "BottleNeckBlock(\n", + " (blocks): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2dAuto(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (1): ReLU(inplace=True)\n", + " (2): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " (3): ReLU(inplace=True)\n", + " (4): Sequential(\n", + " (0): Conv2dAuto(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + " )\n", + " (activation_fn): ReLU(inplace=True)\n", + " (shortcut): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n", + " )\n", + ")\n" + ] + } + 
], "source": [ "dummy = torch.ones((1, 32, 10, 10))\n", "\n", @@ -191,7 +279,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -200,7 +288,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ @@ -218,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -227,7 +315,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -505,7 +593,7 @@ "===============================================================================================" ] }, - "execution_count": 8, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } diff --git a/src/notebooks/04a-look-at-iam-lines.ipynb b/src/notebooks/04a-look-at-iam-lines.ipynb index 093920a..d64b391 100644 --- a/src/notebooks/04a-look-at-iam-lines.ipynb +++ b/src/notebooks/04a-look-at-iam-lines.ipynb @@ -32,7 +32,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -40,8 +40,8 @@ "output_type": "stream", "text": [ "IAM Lines Dataset\n", - "Number classes: 81\n", - "Mapping: {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'A', 11: 'B', 12: 'C', 13: 'D', 14: 'E', 15: 'F', 16: 'G', 17: 'H', 18: 'I', 19: 'J', 20: 'K', 21: 'L', 22: 'M', 23: 'N', 24: 'O', 25: 'P', 26: 'Q', 27: 'R', 28: 'S', 29: 'T', 30: 'U', 31: 'V', 32: 'W', 33: 'X', 34: 'Y', 35: 'Z', 36: 'a', 37: 'b', 38: 'c', 39: 'd', 40: 'e', 41: 'f', 42: 'g', 43: 'h', 44: 'i', 45: 'j', 46: 'k', 47: 'l', 48: 'm', 49: 'n', 50: 'o', 51: 'p', 52: 'q', 53: 'r', 54: 's', 55: 't', 56: 'u', 57: 'v', 58: 'w', 59: 'x', 60: 'y', 61: 'z', 62: ' ', 63: '!', 64: '\"', 65: '#', 66: '&', 67: \"'\", 68: '(', 69: ')', 70: '*', 71: '+', 72: ',', 73: '-', 74: '.', 75: '/', 76: ':', 77: ';', 78: '?', 79: '_', 
80: '<blank>'}\n", + "Number classes: 80\n", + "Mapping: {0: '0', 1: '1', 2: '2', 3: '3', 4: '4', 5: '5', 6: '6', 7: '7', 8: '8', 9: '9', 10: 'A', 11: 'B', 12: 'C', 13: 'D', 14: 'E', 15: 'F', 16: 'G', 17: 'H', 18: 'I', 19: 'J', 20: 'K', 21: 'L', 22: 'M', 23: 'N', 24: 'O', 25: 'P', 26: 'Q', 27: 'R', 28: 'S', 29: 'T', 30: 'U', 31: 'V', 32: 'W', 33: 'X', 34: 'Y', 35: 'Z', 36: 'a', 37: 'b', 38: 'c', 39: 'd', 40: 'e', 41: 'f', 42: 'g', 43: 'h', 44: 'i', 45: 'j', 46: 'k', 47: 'l', 48: 'm', 49: 'n', 50: 'o', 51: 'p', 52: 'q', 53: 'r', 54: 's', 55: 't', 56: 'u', 57: 'v', 58: 'w', 59: 'x', 60: 'y', 61: 'z', 62: ' ', 63: '!', 64: '\"', 65: '#', 66: '&', 67: \"'\", 68: '(', 69: ')', 70: '*', 71: '+', 72: ',', 73: '-', 74: '.', 75: '/', 76: ':', 77: ';', 78: '?', 79: '_'}\n", "Data: (7101, 28, 952)\n", "Targets: (7101, 97)\n", "\n" @@ -56,16 +56,16 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "(97, 81)" + "(97, 80)" ] }, - "execution_count": 55, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -76,7 +76,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -85,7 +85,7 @@ "'A MOVE to stop Mr. 
Gaitskell from'" ] }, - "execution_count": 56, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -251,620 +251,116 @@ }, { "cell_type": "code", - "execution_count": 106, + "execution_count": 46, "metadata": {}, "outputs": [], "source": [ - "data, target = dataset[10]" + "data, target = dataset[0]\n", + "sentence = convert_y_label_to_string(target) " ] }, { "cell_type": "code", - "execution_count": 107, + "execution_count": 47, "metadata": {}, "outputs": [], "source": [ - "text = convert_y_label_to_string(dataset.targets[10])" + "h, w, s = 28, 18, 4" ] }, { "cell_type": "code", - "execution_count": 108, + "execution_count": 48, "metadata": {}, "outputs": [], "source": [ - "data1, target1 = dataset[110]" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "\"Griffiths resolution. Mr. 
Foot's line will\"" - ] - }, - "execution_count": 60, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "text" - ] - }, - { - "cell_type": "code", - "execution_count": 89, - "metadata": {}, - "outputs": [], - "source": [ - "S = 30\n", - "S_min = 10\n", - "N = 16\n", - "C = 20" - ] - }, - { - "cell_type": "code", - "execution_count": 92, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([[ 6, 10, 2, 17, 8, 18, 16, 4, 14, 14, 6, 9, 4, 14, 10, 10, 12, 4,\n", - " 11, 18, 19, 4, 10, 8, 13, 2, 18, 4, 17, 9],\n", - " [ 8, 13, 8, 5, 5, 6, 4, 10, 9, 14, 19, 6, 7, 6, 10, 6, 5, 3,\n", - " 14, 10, 1, 18, 3, 3, 13, 13, 16, 12, 5, 6],\n", - " [ 1, 9, 18, 10, 10, 10, 10, 6, 7, 7, 14, 6, 12, 12, 3, 9, 14, 16,\n", - " 11, 14, 3, 10, 9, 15, 19, 8, 13, 5, 12, 15],\n", - " [10, 2, 6, 11, 14, 1, 13, 1, 7, 2, 19, 1, 1, 17, 6, 16, 18, 12,\n", - " 3, 18, 19, 17, 9, 12, 14, 15, 3, 8, 1, 9],\n", - " [12, 7, 14, 5, 2, 12, 1, 16, 9, 16, 18, 17, 6, 11, 2, 7, 5, 8,\n", - " 16, 6, 19, 13, 12, 17, 11, 13, 17, 12, 5, 1],\n", - " [13, 11, 14, 18, 15, 8, 17, 13, 18, 5, 10, 6, 15, 3, 4, 11, 12, 5,\n", - " 4, 1, 17, 12, 7, 5, 5, 9, 19, 15, 4, 5],\n", - " [12, 1, 4, 5, 6, 13, 19, 1, 1, 15, 3, 14, 8, 19, 7, 5, 19, 9,\n", - " 5, 11, 14, 10, 11, 1, 12, 19, 14, 13, 19, 15],\n", - " [16, 3, 7, 10, 12, 15, 9, 18, 9, 10, 16, 2, 14, 4, 18, 3, 8, 12,\n", - " 16, 19, 11, 5, 9, 19, 11, 14, 7, 16, 4, 3],\n", - " [ 1, 4, 2, 10, 10, 2, 3, 5, 9, 6, 16, 3, 11, 6, 14, 19, 3, 11,\n", - " 6, 3, 19, 17, 2, 12, 12, 4, 5, 15, 19, 1],\n", - " [18, 9, 12, 1, 1, 1, 3, 8, 9, 7, 9, 19, 5, 12, 10, 17, 15, 14,\n", - " 18, 15, 8, 4, 1, 3, 14, 14, 2, 5, 9, 4],\n", - " [15, 5, 15, 14, 8, 9, 8, 15, 12, 15, 18, 2, 7, 13, 19, 12, 1, 16,\n", - " 12, 11, 1, 12, 17, 16, 18, 3, 6, 11, 9, 11],\n", - " [18, 17, 14, 11, 5, 15, 3, 10, 10, 16, 14, 9, 12, 7, 8, 13, 18, 11,\n", - " 6, 9, 16, 10, 14, 6, 5, 19, 11, 13, 7, 14],\n", - " [18, 5, 11, 13, 15, 9, 7, 10, 3, 19, 8, 10, 13, 4, 11, 5, 14, 
17,\n", - " 2, 16, 18, 8, 7, 16, 15, 19, 8, 13, 13, 9],\n", - " [14, 14, 4, 10, 6, 14, 14, 1, 12, 1, 3, 6, 7, 6, 19, 9, 2, 19,\n", - " 13, 9, 1, 2, 11, 2, 2, 10, 3, 11, 1, 15],\n", - " [ 6, 16, 19, 14, 14, 17, 18, 10, 18, 18, 2, 19, 15, 15, 1, 3, 12, 11,\n", - " 17, 7, 15, 6, 10, 2, 13, 17, 13, 6, 8, 14],\n", - " [ 9, 19, 9, 13, 12, 9, 5, 18, 4, 7, 14, 14, 12, 1, 16, 19, 16, 2,\n", - " 15, 4, 14, 9, 15, 2, 9, 13, 12, 1, 12, 11]])" - ] - }, - "execution_count": 92, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)" - ] - }, - { - "cell_type": "code", - "execution_count": 94, - "metadata": {}, - "outputs": [], - "source": [ - "T = 50 # Input sequence length" - ] - }, - { - "cell_type": "code", - "execution_count": 125, - "metadata": {}, - "outputs": [], - "source": [ - "target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)\n", - "target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)" - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": {}, - "outputs": [], - "source": [ - "input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)" - ] - }, - { - "cell_type": "code", - "execution_count": 130, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50, 50])" - ] - }, - "execution_count": 130, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "input_lengths" - ] - }, - { - "cell_type": "code", - "execution_count": 126, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([20, 28, 20, 32, 30, 34, 14, 15, 21, 3, 20, 13, 28, 40, 15, 27])" - ] - }, - "execution_count": 126, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target_lengths" - ] - }, - { - "cell_type": "code", - "execution_count": 128, - "metadata": {}, - 
"outputs": [ - { - "data": { - "text/plain": [ - "tensor([17, 17, 5, 5, 15, 5, 16, 19, 12, 3, 6, 15, 6, 13, 10, 1, 19, 12,\n", - " 2, 13, 19, 13, 7, 4, 19, 9, 19, 1, 3, 16, 1, 12, 11, 8, 17, 6,\n", - " 10, 8, 15, 15, 18, 11, 2, 6, 17, 8, 1, 12, 3, 15, 10, 14, 3, 3,\n", - " 17, 4, 18, 15, 13, 18, 19, 12, 1, 17, 18, 9, 10, 8, 2, 3, 9, 4,\n", - " 7, 9, 12, 11, 9, 5, 12, 10, 4, 15, 8, 6, 17, 9, 7, 7, 18, 15,\n", - " 16, 16, 14, 17, 11, 14, 13, 9, 10, 19, 7, 13, 12, 5, 19, 3, 7, 18,\n", - " 7, 6, 5, 1, 6, 11, 1, 19, 18, 15, 6, 4, 13, 14, 12, 19, 18, 4,\n", - " 15, 14, 12, 1, 14, 18, 1, 4, 1, 7, 12, 6, 3, 9, 8, 19, 7, 13,\n", - " 1, 4, 14, 1, 14, 8, 19, 2, 6, 11, 19, 11, 3, 13, 14, 17, 3, 3,\n", - " 10, 10, 18, 2, 11, 10, 8, 2, 18, 9, 2, 1, 16, 2, 5, 9, 1, 4,\n", - " 16, 18, 12, 11, 12, 13, 13, 18, 2, 3, 2, 7, 18, 8, 2, 16, 12, 18,\n", - " 10, 15, 16, 12, 3, 5, 6, 2, 14, 3, 10, 2, 12, 14, 3, 14, 11, 14,\n", - " 6, 11, 5, 4, 6, 9, 17, 1, 7, 1, 6, 13, 7, 7, 2, 4, 4, 15,\n", - " 1, 11, 10, 12, 10, 4, 3, 3, 7, 19, 5, 19, 18, 17, 1, 11, 14, 12,\n", - " 18, 16, 17, 16, 12, 5, 5, 5, 5, 4, 12, 18, 1, 16, 3, 12, 9, 8,\n", - " 13, 18, 6, 7, 17, 1, 9, 2, 17, 10, 2, 3, 8, 11, 3, 3, 17, 11,\n", - " 17, 19, 6, 12, 4, 11, 18, 3, 18, 16, 7, 9, 6, 15, 10, 17, 1, 17,\n", - " 2, 7, 7, 3, 7, 5, 15, 8, 3, 15, 6, 4, 16, 12, 4, 11, 1, 15,\n", - " 14, 4, 14, 17, 16, 14, 18, 10, 17, 4, 2, 19, 8, 10, 11, 2, 18, 2,\n", - " 19, 10, 15, 15, 14, 10, 3, 16, 14, 5, 15, 4, 7, 13, 16, 14, 3, 14])" - ] - }, - "execution_count": 128, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([16, 53, 44, 41, 41, 44, 55, 43, 54, 62, 53, 40, 54, 50, 47, 56, 55, 44,\n", - " 50, 49, 74, 62, 22, 53, 74, 62, 15, 50, 50, 55, 67, 54, 62, 47, 44, 49,\n", - " 40, 62, 58, 44, 47, 47], dtype=torch.uint8)" - ] - }, - "execution_count": 111, - "metadata": 
{}, - "output_type": "execute_result" - } - ], - "source": [ - "target[target < 79]" - ] - }, - { - "cell_type": "code", - "execution_count": 122, - "metadata": {}, - "outputs": [], - "source": [ - "ts = torch.stack([target, target1])\n", - "\n", - "targets = torch.Tensor([])\n", - "target_lengths = []\n", - "for t in ts:\n", - " t = t[t < 79]\n", - " targets = torch.cat([targets, t])\n", - " target_lengths.append(len(t))\n", - "\n", - "targets = targets.type(dtype=torch.long)\n", - "target_lengths = torch.Tensor(target_lengths).type(dtype=torch.long)" - ] - }, - { - "cell_type": "code", - "execution_count": 123, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([42, 41])" - ] - }, - "execution_count": 123, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target_lengths" - ] - }, - { - "cell_type": "code", - "execution_count": 124, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([16, 53, 44, 41, 41, 44, 55, 43, 54, 62, 53, 40, 54, 50, 47, 56, 55, 44,\n", - " 50, 49, 74, 62, 22, 53, 74, 62, 15, 50, 50, 55, 67, 54, 62, 47, 44, 49,\n", - " 40, 62, 58, 44, 47, 47, 47, 44, 54, 55, 40, 39, 62, 37, 60, 62, 55, 43,\n", - " 40, 62, 16, 50, 57, 40, 53, 49, 48, 40, 49, 55, 74, 62, 18, 48, 48, 40,\n", - " 39, 44, 36, 55, 40, 47, 60, 62, 22, 53, 74])" - ] - }, - "execution_count": 124, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "targets" - ] - }, - { - "cell_type": "code", - "execution_count": 115, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([42., 41.])" - ] - }, - "execution_count": 115, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target_lengths" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([430])" - ] - }, - "execution_count": 99, - "metadata": {}, - "output_type": "execute_result" - } - ], - 
"source": [ - "target.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([10, 22, 19, 13, 15, 23, 28, 14, 22, 21, 16, 14, 22, 28, 22, 26])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)" + "from einops.layers.torch import Rearrange\n", + "slide = nn.Sequential(nn.Unfold(kernel_size=(h, w), stride=(1, s)), Rearrange(\"b (c h w) t -> b t c h w\", h=h, w=w, c=1))" ] }, { "cell_type": "code", - "execution_count": 77, + "execution_count": 49, "metadata": {}, "outputs": [], "source": [ - "targets = torch.stack([target, target])" + "patches = slide(data.unsqueeze(0))" ] }, { "cell_type": "code", - "execution_count": 83, + "execution_count": 50, "metadata": {}, "outputs": [ { - "ename": "RuntimeError", - "evalue": "stack expects each tensor to be equal size, but got [0] at entry 0 and [42] at entry 1", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m<ipython-input-83-692d40f2bb6a>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtargets\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m 
\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mt\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mt\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0;36m79\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mRuntimeError\u001b[0m: stack expects each tensor to be equal size, but got [0] at entry 0 and [42] at entry 1" + "name": "stdout", + "output_type": "stream", + "text": [ + "A MOVE to stop Mr. Gaitskell from\n" ] - } - ], - "source": [ - "ts = torch.Tensor()\n", - "for i, t in enumerate(targets):\n", - " torch.stack([ts, t[t < 79]])" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([16, 53, 44, 41, 41, 44, 55, 43, 54, 62, 53, 40, 54, 50, 47, 56, 55, 44,\n", - " 50, 49, 74, 62, 22, 53, 74, 62, 15, 50, 50, 55, 67, 54, 62, 47, 44, 49,\n", - " 40, 62, 58, 44, 47, 47, 16, 53, 44, 41, 41, 44, 55, 43, 54, 62, 53, 40,\n", - " 54, 50, 47, 56, 55, 44, 50, 49, 74, 62, 22, 53, 74, 62, 15, 50, 50, 55,\n", - " 67, 54, 62, 47, 44, 49, 40, 62, 58, 44, 47, 47], dtype=torch.uint8)" - ] - }, - "execution_count": 78, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "targets[targets<79]" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import LineRecurrentNetwork" - ] - }, - { - "cell_type": "code", - "execution_count": 52, - "metadata": {}, - "outputs": [], - "source": [ - "crnn = LineRecurrentNetwork(encoder=\"ResidualNetworkEncoder\",\n", - " \n", - " encoder_args={\n", - " \"in_channels\": 1,\n", - " \"num_classes\": 80,\n", - " \"depths\": [2, 2],\n", - " \"block_sizes\": [64, 128],\n", - " \"activation\": \"leaky_relu\",\n", - " \"stn\": False,},\n", - " patch_size=[28, 
14], \n", - " stride=[1, 6])" - ] - }, - { - "cell_type": "code", - "execution_count": 53, - "metadata": {}, - "outputs": [], - "source": [ - "output = crnn(data)" - ] - }, - { - "cell_type": "code", - "execution_count": 54, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([157, 1, 80])" - ] - }, - "execution_count": 54, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "output.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 55, - "metadata": {}, - "outputs": [], - "source": [ - "# output = output.unsqueeze(0)\n", - "targets = target.unsqueeze(0).type(torch.long)" - ] - }, - { - "cell_type": "code", - "execution_count": 56, - "metadata": {}, - "outputs": [], - "source": [ - "input_lengths = torch.full(\n", - " size=(output.shape[1],), fill_value=output.shape[0], dtype=torch.long,\n", - ")\n", - "target_lengths = torch.full(\n", - " size=(output.shape[1],), fill_value=targets.shape[1], dtype=torch.long,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": 57, - "metadata": {}, - "outputs": [], - "source": [ - "ctc = nn.CTCLoss(blank=0)" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ + }, { "data": { + "image/png": "\n", "text/plain": [ - "torch.Size([1, 97])" + "<Figure size 1440x1440 with 60 Axes>" ] }, - "execution_count": 58, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "targets.shape" + "# remove batch size\n", + "n = 60\n", + "patches = patches.squeeze(0)\n", + "fig = plt.figure(figsize=(20, 20))\n", + "print(sentence)\n", + "for i in range(n):\n", + " ax = fig.add_subplot(1, n, i + 1)\n", + " ax.imshow(patches[i].squeeze(0), cmap='gray')\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])" ] }, { "cell_type": "code", - "execution_count": 59, + "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([157, 1, 80])" + "torch.Size([234, 
1, 28, 18])" ] }, - "execution_count": 59, + "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "output.shape" + "patches.shape" ] }, { "cell_type": "code", - "execution_count": 60, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "tensor(6.9447, grad_fn=<MeanBackward0>)" + "24.0" ] }, - "execution_count": 60, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "ctc(output, targets, input_lengths, target_lengths)" - ] - }, - { - "cell_type": "code", - "execution_count": 64, - "metadata": {}, - "outputs": [], - "source": [ - "from einops.layers.torch import Rearrange\n", - "slide = nn.Sequential(nn.Unfold(kernel_size=(28, 14), stride=(1, 5)), Rearrange(\"b (c h w) t -> b t c h w\", h=28, w=14, c=1))" - ] - }, - { - "cell_type": "code", - "execution_count": 65, - "metadata": {}, - "outputs": [], - "source": [ - "patches = slide(data.unsqueeze(0))" - ] - }, - { - "cell_type": "code", - "execution_count": 68, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "<Figure size 1440x1440 with 6 Axes>" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "# remove batch size\n", - "patches = patches.squeeze(0)\n", - "fig = plt.figure(figsize=(20, 20))\n", - "for i in range(6):\n", - " ax = fig.add_subplot(1, 6, i + 1)\n", - " ax.imshow(patches[i].squeeze(0), cmap='gray')" + "32 * 0.75" ] }, { diff --git a/src/tasks/prepare_experiments.sh b/src/tasks/prepare_experiments.sh new file mode 100755 index 0000000..9b91daa --- /dev/null +++ b/src/tasks/prepare_experiments.sh @@ -0,0 +1,3 @@ +#!/bin/bash +experiments_filename=${1:-training/experiments/sample_experiment.yml} +python training/prepare_experiments.py --experiments_filename $experiments_filename diff --git a/src/tasks/prepare_sample_experiments.sh b/src/tasks/prepare_sample_experiments.sh deleted file mode 
100755 index bc34f48..0000000 --- a/src/tasks/prepare_sample_experiments.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -python training/prepare_experiments.py --experiments_filename training/experiments/sample_experiment.yml diff --git a/src/tasks/train_crnn_line_ctc_model.sh b/src/tasks/train_crnn_line_ctc_model.sh new file mode 100644 index 0000000..9831289 --- /dev/null +++ b/src/tasks/train_crnn_line_ctc_model.sh @@ -0,0 +1,3 @@ +#!/bin/bash +experiments_filename=${1:-training/experiments/line_ctc_experiment.yml} +exec ./prepare_experiments.sh experiments_filename=experiments_filename diff --git a/src/text_recognizer/character_predictor.py b/src/text_recognizer/character_predictor.py index b733a53..df37e68 100644 --- a/src/text_recognizer/character_predictor.py +++ b/src/text_recognizer/character_predictor.py @@ -15,6 +15,7 @@ class CharacterPredictor: """Intializes the CharacterModel and load the pretrained weights.""" self.model = CharacterModel(network_fn=network_fn) self.model.eval() + self.model.use_swa_model() def predict(self, image_or_filename: Union[np.ndarray, str]) -> Tuple[str, float]: """Predict on a single images contianing a handwritten character.""" diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py index d23fe56..caf8065 100644 --- a/src/text_recognizer/models/base.py +++ b/src/text_recognizer/models/base.py @@ -77,9 +77,9 @@ class Model(ABC): # Stochastic Weight Averaging placeholders. self.swa_args = swa_args - self._swa_start = None self._swa_scheduler = None self._swa_network = None + self._use_swa_model = False # Experiment directory. 
self.model_dir = None @@ -220,15 +220,24 @@ class Model(ABC): if self._optimizer and self._lr_scheduler is not None: if "OneCycleLR" in str(self._lr_scheduler): self.lr_scheduler_args["steps_per_epoch"] = len(self.train_dataloader()) - self._lr_scheduler = self._lr_scheduler( - self._optimizer, **self.lr_scheduler_args - ) - else: - self._lr_scheduler = None + + # Assume lr scheduler should update at each epoch if not specified. + if "interval" not in self.lr_scheduler_args: + interval = "epoch" + else: + interval = self.lr_scheduler_args.pop("interval") + self._lr_scheduler = { + "lr_scheduler": self._lr_scheduler( + self._optimizer, **self.lr_scheduler_args + ), + "interval": interval, + } if self.swa_args is not None: - self._swa_start = self.swa_args["start"] - self._swa_scheduler = SWALR(self._optimizer, swa_lr=self.swa_args["lr"]) + self._swa_scheduler = { + "swa_scheduler": SWALR(self._optimizer, swa_lr=self.swa_args["lr"]), + "swa_start": self.swa_args["start"], + } self._swa_network = AveragedModel(self._network).to(self.device) @property @@ -280,21 +289,16 @@ class Model(ABC): return self._optimizer @property - def lr_scheduler(self) -> Optional[Callable]: - """Learning rate scheduler.""" + def lr_scheduler(self) -> Optional[Dict]: + """Returns a directory with the learning rate scheduler.""" return self._lr_scheduler @property - def swa_scheduler(self) -> Optional[Callable]: - """Returns the stochastic weight averaging scheduler.""" + def swa_scheduler(self) -> Optional[Dict]: + """Returns a directory with the stochastic weight averaging scheduler.""" return self._swa_scheduler @property - def swa_start(self) -> Optional[Callable]: - """Returns the start epoch of stochastic weight averaging.""" - return self._swa_start - - @property def swa_network(self) -> Optional[Callable]: """Returns the stochastic weight averaging network.""" return self._swa_network @@ -311,20 +315,32 @@ class Model(ABC): WEIGHT_DIRNAME.mkdir(parents=True, exist_ok=True) return 
str(WEIGHT_DIRNAME / f"{self._name}_weights.pt") + def use_swa_model(self) -> None: + """Set to use predictions from SWA model.""" + if self.swa_network is not None: + self._use_swa_model = True + + def forward(self, x: Tensor) -> Tensor: + """Feedforward pass with the network.""" + if self._use_swa_model: + return self.swa_network(x) + else: + return self.network(x) + def loss_fn(self, output: Tensor, targets: Tensor) -> Tensor: """Compute the loss.""" return self.criterion(output, targets) def summary( - self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 5 + self, input_shape: Optional[Tuple[int, int, int]] = None, depth: int = 3 ) -> None: """Prints a summary of the network architecture.""" if input_shape is not None: - summary(self._network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=self.device) elif self._input_shape is not None: input_shape = (1,) + tuple(self._input_shape) - summary(self._network, input_shape, depth=depth, device=self.device) + summary(self.network, input_shape, depth=depth, device=self.device) else: logger.warning("Could not print summary as input shape is not set.") diff --git a/src/text_recognizer/models/character_model.py b/src/text_recognizer/models/character_model.py index 64ba693..50e94a2 100644 --- a/src/text_recognizer/models/character_model.py +++ b/src/text_recognizer/models/character_model.py @@ -75,11 +75,7 @@ class CharacterModel(Model): # Put the image tensor on the device the model weights are on. 
image = image.to(self.device) - logits = ( - self.swa_network(image) - if self.swa_network is not None - else self.network(image) - ) + logits = self.forward(image) prediction = self.softmax(logits.squeeze(0)) diff --git a/src/text_recognizer/models/line_ctc_model.py b/src/text_recognizer/models/line_ctc_model.py index af41f18..16eaed3 100644 --- a/src/text_recognizer/models/line_ctc_model.py +++ b/src/text_recognizer/models/line_ctc_model.py @@ -98,16 +98,12 @@ class LineCTCModel(Model): # Put the image tensor on the device the model weights are on. image = image.to(self.device) - log_probs = ( - self.swa_network(image) - if self.swa_network is not None - else self.network(image) - ) + log_probs = self.forward(image) raw_pred, _ = greedy_decoder( predictions=log_probs, character_mapper=self.mapper, - blank_label=80, + blank_label=79, collapse_repeated=True, ) diff --git a/src/text_recognizer/networks/__init__.py b/src/text_recognizer/networks/__init__.py index d20c86a..a39975f 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/src/text_recognizer/networks/__init__.py @@ -2,12 +2,14 @@ from .ctc import greedy_decoder from .lenet import LeNet from .line_lstm_ctc import LineRecurrentNetwork +from .losses import EmbeddingLoss from .misc import sliding_window from .mlp import MLP from .residual_network import ResidualNetwork, ResidualNetworkEncoder from .wide_resnet import WideResidualNetwork __all__ = [ + "EmbeddingLoss", "greedy_decoder", "MLP", "LeNet", diff --git a/src/text_recognizer/networks/line_lstm_ctc.py b/src/text_recognizer/networks/line_lstm_ctc.py index 5c57479..9009f94 100644 --- a/src/text_recognizer/networks/line_lstm_ctc.py +++ b/src/text_recognizer/networks/line_lstm_ctc.py @@ -1,9 +1,11 @@ """LSTM with CTC for handwritten text recognition within a line.""" import importlib +from pathlib import Path from typing import Callable, Dict, List, Optional, Tuple, Type, Union from einops import rearrange, reduce from einops.layers.torch import 
Rearrange, Reduce +from loguru import logger import torch from torch import nn from torch import Tensor @@ -14,40 +16,72 @@ class LineRecurrentNetwork(nn.Module): def __init__( self, - encoder: str, - encoder_args: Dict = None, + backbone: str, + backbone_args: Dict = None, flatten: bool = True, input_size: int = 128, hidden_size: int = 128, + bidirectional: bool = False, num_layers: int = 1, num_classes: int = 80, patch_size: Tuple[int, int] = (28, 28), stride: Tuple[int, int] = (1, 14), ) -> None: super().__init__() - self.encoder_args = encoder_args or {} + self.backbone_args = backbone_args or {} self.patch_size = patch_size self.stride = stride self.sliding_window = self._configure_sliding_window() self.input_size = input_size self.hidden_size = hidden_size - self.encoder = self._configure_encoder(encoder) + self.backbone = self._configure_backbone(backbone) + self.bidirectional = bidirectional self.flatten = flatten - self.fc = nn.Linear(in_features=self.input_size, out_features=self.hidden_size) + + if self.flatten: + self.fc = nn.Linear( + in_features=self.input_size, out_features=self.hidden_size + ) + self.rnn = nn.LSTM( input_size=self.hidden_size, hidden_size=self.hidden_size, + bidirectional=bidirectional, num_layers=num_layers, ) + + decoder_size = self.hidden_size * 2 if self.bidirectional else self.hidden_size + self.decoder = nn.Sequential( - nn.Linear(in_features=self.hidden_size, out_features=num_classes), + nn.Linear(in_features=decoder_size, out_features=num_classes), nn.LogSoftmax(dim=2), ) - def _configure_encoder(self, encoder: str) -> Type[nn.Module]: + def _configure_backbone(self, backbone: str) -> Type[nn.Module]: network_module = importlib.import_module("text_recognizer.networks") - encoder_ = getattr(network_module, encoder) - return encoder_(**self.encoder_args) + backbone_ = getattr(network_module, backbone) + + if "pretrained" in self.backbone_args: + logger.info("Loading pretrained backbone.") + checkpoint_file = 
Path(__file__).resolve().parents[ + 2 + ] / self.backbone_args.pop("pretrained") + + # Loading state directory. + state_dict = torch.load(checkpoint_file) + network_args = state_dict["network_args"] + weights = state_dict["model_state"] + + # Initializes the network with trained weights. + backbone = backbone_(**network_args) + backbone.load_state_dict(weights) + if "freeze" in self.backbone_args and self.backbone_args["freeze"] is True: + for params in backbone.parameters(): + params.requires_grad = False + + return backbone + else: + return backbone_(**self.backbone_args) def _configure_sliding_window(self) -> nn.Sequential: return nn.Sequential( @@ -69,13 +103,14 @@ class LineRecurrentNetwork(nn.Module): # Rearrange from a sequence of patches for feedforward network. b, t = x.shape[:2] x = rearrange(x, "b t c h w -> (b t) c h w", b=b, t=t) - x = self.encoder(x) + x = self.backbone(x) # Avgerage pooling. - x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t) if self.flatten else x - - # Linear layer between CNN and RNN - x = self.fc(x) + x = ( + self.fc(reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)) + if self.flatten + else rearrange(x, "(b t) h -> t b h", b=b, t=t) + ) # Sequence predictions. 
x, _ = self.rnn(x) diff --git a/src/text_recognizer/networks/losses.py b/src/text_recognizer/networks/losses.py new file mode 100644 index 0000000..73e0641 --- /dev/null +++ b/src/text_recognizer/networks/losses.py @@ -0,0 +1,31 @@ +"""Implementations of custom loss functions.""" +from pytorch_metric_learning import distances, losses, miners, reducers +from torch import nn +from torch import Tensor + + +class EmbeddingLoss: + """Metric loss for training encoders to produce information-rich latent embeddings.""" + + def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None: + self.distance = distances.CosineSimilarity() + self.reducer = reducers.ThresholdReducer(low=0) + self.loss_fn = losses.TripletMarginLoss( + margin=margin, distance=self.distance, reducer=self.reducer + ) + self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=self.distance) + + def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor: + """Computes the metric loss for the embeddings based on their labels. + + Args: + embeddings (Tensor): The latent vectors encoded by the network. + labels (Tensor): Labels of the embeddings. + + Returns: + Tensor: The metric loss for the embeddings. 
+ + """ + hard_pairs = self.miner(embeddings, labels) + loss = self.loss_fn(embeddings, labels, hard_pairs) + return loss diff --git a/src/text_recognizer/networks/residual_network.py b/src/text_recognizer/networks/residual_network.py index 1b5d6b3..046600d 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/src/text_recognizer/networks/residual_network.py @@ -278,7 +278,8 @@ class ResidualNetworkEncoder(nn.Module): if self.stn is not None: x = self.stn(x) x = self.gate(x) - return self.blocks(x) + x = self.blocks(x) + return x class ResidualNetworkDecoder(nn.Module): diff --git a/src/text_recognizer/networks/transformer.py b/src/text_recognizer/networks/transformer.py index 868d739..c091ba0 100644 --- a/src/text_recognizer/networks/transformer.py +++ b/src/text_recognizer/networks/transformer.py @@ -1 +1,5 @@ """TBC.""" +from typing import Dict + +import torch +from torch import Tensor diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt Binary files differnew file mode 100644 index 0000000..9f9deee --- /dev/null +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetworkEncoder_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt Binary files differindex a25bcd1..0dc7eb5 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt +++ b/src/text_recognizer/weights/CharacterModel_EmnistDataset_ResidualNetwork_weights.pt diff --git a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt Binary files differindex 9bd8ca2..93d34d7 100644 --- 
a/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt +++ b/src/text_recognizer/weights/LineCTCModel_IamLinesDataset_LineRecurrentNetwork_weights.pt diff --git a/src/training/experiments/iam_line_ctc_experiment.yml b/src/training/experiments/iam_line_ctc_experiment.yml deleted file mode 100644 index 141c74e..0000000 --- a/src/training/experiments/iam_line_ctc_experiment.yml +++ /dev/null @@ -1,94 +0,0 @@ -experiment_group: Sample Experiments -experiments: - - train_args: - batch_size: 24 - max_epochs: 128 - dataset: - type: IamLinesDataset - args: - subsample_fraction: null - transform: null - target_transform: null - train_args: - num_workers: 6 - train_fraction: 0.85 - model: LineCTCModel - metrics: [cer, wer] - network: - type: LineRecurrentNetwork - args: - # encoder: ResidualNetworkEncoder - # encoder_args: - # in_channels: 1 - # num_classes: 80 - # depths: [2, 2] - # block_sizes: [128, 128] - # activation: SELU - # stn: false - encoder: WideResidualNetwork - encoder_args: - in_channels: 1 - num_classes: 80 - depth: 16 - num_layers: 4 - width_factor: 2 - dropout_rate: 0.2 - activation: selu - use_decoder: false - flatten: true - input_size: 256 - hidden_size: 128 - num_layers: 2 - num_classes: 80 - patch_size: [28, 14] - stride: [1, 5] - criterion: - type: CTCLoss - args: - blank: 79 - optimizer: - type: AdamW - args: - lr: 1.e-03 - betas: [0.9, 0.999] - eps: 1.e-08 - weight_decay: false - amsgrad: false - # lr_scheduler: - # type: OneCycleLR - # args: - # max_lr: 1.e-02 - # epochs: null - # anneal_strategy: linear - lr_scheduler: - type: CosineAnnealingLR - args: - T_max: null - swa_args: - start: 75 - lr: 5.e-2 - callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR] - callback_args: - Checkpoint: - monitor: val_loss - mode: min - ProgressBar: - epochs: null - # log_batch_frequency: 100 - # EarlyStopping: - # monitor: val_loss - # min_delta: 0.0 - # patience: 7 - # mode: 
min - WandbCallback: - log_batch_frequency: 10 - WandbImageLogger: - num_examples: 6 - # OneCycleLR: - # null - SWA: - null - verbosity: 1 # 0, 1, 2 - resume_experiment: null - test: true - test_metric: test_cer diff --git a/src/training/experiments/line_ctc_experiment.yml b/src/training/experiments/line_ctc_experiment.yml index c21c6a2..432d1cc 100644 --- a/src/training/experiments/line_ctc_experiment.yml +++ b/src/training/experiments/line_ctc_experiment.yml @@ -1,55 +1,46 @@ -experiment_group: Sample Experiments +experiment_group: Lines Experiments experiments: - train_args: - batch_size: 64 - max_epochs: 32 + batch_size: 42 + max_epochs: &max_epochs 32 dataset: - type: EmnistLinesDataset + type: IamLinesDataset args: - subsample_fraction: 0.33 - max_length: 34 - min_overlap: 0 - max_overlap: 0.33 - num_samples: 10000 - seed: 4711 - blank: true + subsample_fraction: null + transform: null + target_transform: null train_args: - num_workers: 6 + num_workers: 8 train_fraction: 0.85 model: LineCTCModel metrics: [cer, wer] network: type: LineRecurrentNetwork args: - # encoder: ResidualNetworkEncoder - # encoder_args: - # in_channels: 1 - # num_classes: 81 - # depths: [2, 2] - # block_sizes: [64, 128] - # activation: SELU - # stn: false - encoder: WideResidualNetwork - encoder_args: + backbone: ResidualNetwork + backbone_args: in_channels: 1 - num_classes: 81 - depth: 16 - num_layers: 4 - width_factor: 2 - dropout_rate: 0.2 + num_classes: 64 # Embedding + depths: [2,2] + block_sizes: [32,64] activation: selu - use_decoder: false - flatten: true - input_size: 256 - hidden_size: 128 + stn: false + # encoder: ResidualNetwork + # encoder_args: + # pretrained: training/experiments/CharacterModel_EmnistDataset_ResidualNetwork/0917_203601/model/best.pt + # freeze: false + flatten: false + input_size: 64 + hidden_size: 64 + bidirectional: true num_layers: 2 - num_classes: 81 - patch_size: [28, 14] - stride: [1, 5] + num_classes: 80 + patch_size: [28, 18] + stride: [1, 4] 
criterion: type: CTCLoss args: - blank: 80 + blank: 79 optimizer: type: AdamW args: @@ -58,40 +49,42 @@ experiments: eps: 1.e-08 weight_decay: 5.e-4 amsgrad: false - # lr_scheduler: - # type: OneCycleLR - # args: - # max_lr: 1.e-03 - # epochs: null - # anneal_strategy: linear lr_scheduler: - type: CosineAnnealingLR + type: OneCycleLR args: - T_max: null + max_lr: 1.e-02 + epochs: *max_epochs + anneal_strategy: cos + pct_start: 0.475 + cycle_momentum: true + base_momentum: 0.85 + max_momentum: 0.9 + div_factor: 10 + final_div_factor: 10000 + interval: step + # lr_scheduler: + # type: CosineAnnealingLR + # args: + # T_max: *max_epochs swa_args: - start: 4 + start: 24 lr: 5.e-2 - callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, SWA] # EarlyStopping, OneCycleLR] + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger] # EarlyStopping] callback_args: Checkpoint: monitor: val_loss mode: min ProgressBar: - epochs: null - log_batch_frequency: 100 + epochs: *max_epochs # EarlyStopping: # monitor: val_loss # min_delta: 0.0 - # patience: 5 + # patience: 10 # mode: min WandbCallback: log_batch_frequency: 10 WandbImageLogger: num_examples: 6 - # OneCycleLR: - # null - SWA: - null verbosity: 1 # 0, 1, 2 resume_experiment: null test: true diff --git a/src/training/experiments/sample_experiment.yml b/src/training/experiments/sample_experiment.yml index 17e220e..8664a15 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/src/training/experiments/sample_experiment.yml @@ -2,7 +2,7 @@ experiment_group: Sample Experiments experiments: - train_args: batch_size: 256 - max_epochs: 32 + max_epochs: &max_epochs 32 dataset: type: EmnistDataset args: @@ -66,16 +66,17 @@ experiments: # type: OneCycleLR # args: # max_lr: 1.e-03 - # epochs: null + # epochs: *max_epochs # anneal_strategy: linear lr_scheduler: type: CosineAnnealingLR args: - T_max: null + T_max: *max_epochs + interval: epoch swa_args: start: 2 lr: 5.e-2 - callbacks: [Checkpoint, 
ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping, SWA] # OneCycleLR] + callbacks: [Checkpoint, ProgressBar, WandbCallback, WandbImageLogger, EarlyStopping] callback_args: Checkpoint: monitor: val_accuracy @@ -92,10 +93,6 @@ experiments: WandbImageLogger: num_examples: 4 use_transpose: true - # OneCycleLR: - # null - SWA: - null verbosity: 0 # 0, 1, 2 resume_experiment: null test: true diff --git a/src/training/run_experiment.py b/src/training/run_experiment.py index 286b0c6..a347d9f 100644 --- a/src/training/run_experiment.py +++ b/src/training/run_experiment.py @@ -10,6 +10,7 @@ from typing import Callable, Dict, List, Tuple, Type import click from loguru import logger +import numpy as np import torch from tqdm import tqdm from training.gpu_manager import GPUManager @@ -20,11 +21,12 @@ import yaml from text_recognizer.models import Model +from text_recognizer.networks import losses EXPERIMENTS_DIRNAME = Path(__file__).parents[0].resolve() / "experiments" - +CUSTOM_LOSSES = ["EmbeddingLoss"] DEFAULT_TRAIN_ARGS = {"batch_size": 64, "epochs": 16} @@ -69,21 +71,6 @@ def create_experiment_dir(experiment_config: Dict) -> Path: return experiment_dir, log_dir, model_dir -def check_args(args: Dict, train_args: Dict) -> Dict: - """Checks that the arguments are not None.""" - args = args or {} - - # I just want to set total epochs in train args, instead of changing all parameter. - if "epochs" in args and args["epochs"] is None: - args["epochs"] = train_args["max_epochs"] - - # For CosineAnnealingLR. - if "T_max" in args and args["T_max"] is None: - args["T_max"] = train_args["max_epochs"] - - return args or {} - - def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict]: """Loads all modules and arguments.""" # Import the data loader arguments. 
@@ -115,8 +102,12 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] network_args = experiment_config["network"].get("args", {}) # Criterion - criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) - criterion_args = experiment_config["criterion"].get("args", {}) + if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: + criterion_ = getattr(losses, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) + else: + criterion_ = getattr(torch.nn, experiment_config["criterion"]["type"]) + criterion_args = experiment_config["criterion"].get("args", {}) # Optimizers optimizer_ = getattr(torch.optim, experiment_config["optimizer"]["type"]) @@ -129,13 +120,11 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] lr_scheduler_ = getattr( torch.optim.lr_scheduler, experiment_config["lr_scheduler"]["type"] ) - lr_scheduler_args = check_args( - experiment_config["lr_scheduler"].get("args", {}), train_args - ) + lr_scheduler_args = experiment_config["lr_scheduler"].get("args", {}) or {} # SWA scheduler. if "swa_args" in experiment_config: - swa_args = check_args(experiment_config.get("swa_args", {}), train_args) + swa_args = experiment_config.get("swa_args", {}) or {} else: swa_args = None @@ -159,19 +148,15 @@ def load_modules_and_arguments(experiment_config: Dict) -> Tuple[Callable, Dict] def configure_callbacks(experiment_config: Dict, model_dir: Dict) -> CallbackList: """Configure a callback list for trainer.""" - train_args = experiment_config.get("train_args", {}) - if "Checkpoint" in experiment_config["callback_args"]: experiment_config["callback_args"]["Checkpoint"]["checkpoint_path"] = model_dir - # Callbacks + # Initializes callbacks. 
callback_modules = importlib.import_module("training.trainer.callbacks") - callbacks = [ - getattr(callback_modules, callback)( - **check_args(experiment_config["callback_args"][callback], train_args) - ) - for callback in experiment_config["callbacks"] - ] + callbacks = [] + for callback in experiment_config["callbacks"]: + args = experiment_config["callback_args"][callback] or {} + callbacks.append(getattr(callback_modules, callback)(**args)) return callbacks @@ -207,11 +192,35 @@ def load_from_checkpoint(model: Type[Model], log_dir: Path, model_dir: Path) -> model.load_checkpoint(checkpoint_path) +def evaluate_embedding(model: Type[Model]) -> Dict: + """Evaluates the embedding space.""" + from pytorch_metric_learning import testers + from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator + + accuracy_calculator = AccuracyCalculator( + include=("mean_average_precision_at_r",), k=10 + ) + + def get_all_embeddings(model: Type[Model]) -> Tuple: + tester = testers.BaseTester() + return tester.get_all_embeddings(model.test_dataset, model.network) + + embeddings, labels = get_all_embeddings(model) + logger.info("Computing embedding accuracy") + accuracies = accuracy_calculator.get_accuracy( + embeddings, embeddings, np.squeeze(labels), np.squeeze(labels), True + ) + logger.info( + f"Test set accuracy (MAP@10) = {accuracies['mean_average_precision_at_r']}" + ) + return accuracies + + def run_experiment( experiment_config: Dict, save_weights: bool, device: str, use_wandb: bool = False ) -> None: """Runs an experiment.""" - logger.info(f"Experiment config: {json.dumps(experiment_config, indent=2)}") + logger.info(f"Experiment config: {json.dumps(experiment_config)}") # Create new experiment. 
experiment_dir, log_dir, model_dir = create_experiment_dir(experiment_config) @@ -272,7 +281,11 @@ def run_experiment( model.load_from_checkpoint(model_dir / "best.pt") logger.info("Running inference on test set.") - score = trainer.test(model) + if experiment_config["criterion"]["type"] in CUSTOM_LOSSES: + logger.info("Evaluating embedding.") + score = evaluate_embedding(model) + else: + score = trainer.test(model) logger.info(f"Test set evaluation: {score}") diff --git a/src/training/trainer/callbacks/__init__.py b/src/training/trainer/callbacks/__init__.py index c81e4bf..e1bd858 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/src/training/trainer/callbacks/__init__.py @@ -3,12 +3,7 @@ from .base import Callback, CallbackList from .checkpoint import Checkpoint from .early_stopping import EarlyStopping from .lr_schedulers import ( - CosineAnnealingLR, - CyclicLR, - MultiStepLR, - OneCycleLR, - ReduceLROnPlateau, - StepLR, + LRScheduler, SWA, ) from .progress_bar import ProgressBar @@ -18,15 +13,10 @@ __all__ = [ "Callback", "CallbackList", "Checkpoint", - "CosineAnnealingLR", "EarlyStopping", + "LRScheduler", "WandbCallback", "WandbImageLogger", - "CyclicLR", - "MultiStepLR", - "OneCycleLR", "ProgressBar", - "ReduceLROnPlateau", - "StepLR", "SWA", ] diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/src/training/trainer/callbacks/lr_schedulers.py index bb41d2d..907e292 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/src/training/trainer/callbacks/lr_schedulers.py @@ -7,113 +7,27 @@ from training.trainer.callbacks import Callback from text_recognizer.models import Model -class StepLR(Callback): - """Callback for StepLR.""" +class LRScheduler(Callback): + """Generic learning rate scheduler callback.""" def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = 
model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class MultiStepLR(Callback): - """Callback for MultiStepLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() - - -class ReduceLROnPlateau(Callback): - """Callback for ReduceLROnPlateau.""" - - def __init__(self) -> None: - """Initializes the callback.""" super().__init__() - self.lr_scheduler = None def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every epoch.""" - val_loss = logs["val_loss"] - self.lr_scheduler.step(val_loss) - - -class CyclicLR(Callback): - """Callback for CyclicLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class OneCycleLR(Callback): - """Callback for OneCycleLR.""" - - def __init__(self) -> None: - """Initializes the callback.""" - 
super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler + if self.interval == "epoch": + self.lr_scheduler.step() def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" - self.lr_scheduler.step() - - -class CosineAnnealingLR(Callback): - """Callback for Cosine Annealing.""" - - def __init__(self) -> None: - """Initializes the callback.""" - super().__init__() - self.lr_scheduler = None - - def set_model(self, model: Type[Model]) -> None: - """Sets the model and lr scheduler.""" - self.model = model - self.lr_scheduler = self.model.lr_scheduler - - def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: - """Takes a step at the end of every epoch.""" - self.lr_scheduler.step() + if self.interval == "step": + self.lr_scheduler.step() class SWA(Callback): @@ -122,21 +36,32 @@ class SWA(Callback): def __init__(self) -> None: """Initializes the callback.""" super().__init__() + self.lr_scheduler = None + self.interval = None self.swa_scheduler = None + self.swa_start = None + self.current_epoch = 1 def set_model(self, model: Type[Model]) -> None: """Sets the model and lr scheduler.""" self.model = model - self.swa_start = self.model.swa_start - self.swa_scheduler = self.model.lr_scheduler - self.lr_scheduler = self.model.lr_scheduler + self.lr_scheduler = self.model.lr_scheduler["lr_scheduler"] + self.interval = self.model.lr_scheduler["interval"] + self.swa_scheduler = self.model.swa_scheduler["swa_scheduler"] + self.swa_start = self.model.swa_scheduler["swa_start"] def on_epoch_end(self, epoch: int, logs: Optional[Dict] = None) -> None: """Takes a step at the end of every training batch.""" if epoch > self.swa_start: self.model.swa_network.update_parameters(self.model.network) self.swa_scheduler.step() - else: + elif 
self.interval == "epoch": + self.lr_scheduler.step() + self.current_epoch = epoch + + def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: + """Takes a step at the end of every training batch.""" + if self.current_epoch < self.swa_start and self.interval == "step": self.lr_scheduler.step() def on_fit_end(self) -> None: diff --git a/src/training/trainer/callbacks/progress_bar.py b/src/training/trainer/callbacks/progress_bar.py index 7829fa0..6c4305a 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/src/training/trainer/callbacks/progress_bar.py @@ -11,6 +11,7 @@ class ProgressBar(Callback): def __init__(self, epochs: int, log_batch_frequency: int = None) -> None: """Initializes the tqdm callback.""" self.epochs = epochs + print(epochs, type(epochs)) self.log_batch_frequency = log_batch_frequency self.progress_bar = None self.val_metrics = {} diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/src/training/trainer/callbacks/wandb_callbacks.py index 6643a44..d2df4d7 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/src/training/trainer/callbacks/wandb_callbacks.py @@ -32,6 +32,7 @@ class WandbCallback(Callback): def on_train_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: """Logs training metrics.""" if logs is not None: + logs["lr"] = self.model.optimizer.param_groups[0]["lr"] self._on_batch_end(batch, logs) def on_validation_batch_end(self, batch: int, logs: Optional[Dict] = None) -> None: diff --git a/src/training/trainer/train.py b/src/training/trainer/train.py index b240157..bd6a491 100644 --- a/src/training/trainer/train.py +++ b/src/training/trainer/train.py @@ -9,7 +9,7 @@ import numpy as np import torch from torch import Tensor from torch.optim.swa_utils import update_bn -from training.trainer.callbacks import Callback, CallbackList +from training.trainer.callbacks import Callback, CallbackList, LRScheduler, SWA from training.trainer.util import log_val_metric, 
RunningAverage import wandb @@ -47,8 +47,14 @@ class Trainer: self.model = None def _configure_callbacks(self) -> None: + """Instantiate the CallbackList.""" if not self.callbacks_configured: - # Instantiate a CallbackList. + # If learning rate schedulers are present, they need to be added to the callbacks. + if self.model.swa_scheduler is not None: + self.callbacks.append(SWA()) + elif self.model.lr_scheduler is not None: + self.callbacks.append(LRScheduler()) + self.callbacks = CallbackList(self.model, self.callbacks) def compute_metrics( @@ -91,7 +97,7 @@ class Trainer: # Forward pass. # Get the network prediction. - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -130,7 +136,6 @@ class Trainer: batch: int, samples: Tuple[Tensor, Tensor], loss_avg: Type[RunningAverage], - use_swa: bool = False, ) -> Dict: """Performs the validation step.""" # Pass the tensor to the device for computation. @@ -143,10 +148,7 @@ class Trainer: # Forward pass. # Get the network prediction. # Use SWA if available and using test dataset. - if use_swa and self.model.swa_network is None: - output = self.model.swa_network(data) - else: - output = self.model.network(data) + output = self.model.forward(data) # Compute the loss. loss = self.model.loss_fn(output, targets) @@ -238,7 +240,7 @@ class Trainer: self.model.eval() # Check if SWA network is available. - use_swa = True if self.model.swa_network is not None else False + self.model.use_swa_model() # Running average for the loss. loss_avg = RunningAverage() @@ -247,7 +249,7 @@ class Trainer: summary = [] for batch, samples in enumerate(self.model.test_dataloader()): - metrics = self.validation_step(batch, samples, loss_avg, use_swa) + metrics = self.validation_step(batch, samples, loss_avg) summary.append(metrics) # Compute mean of all test metrics. |