summaryrefslogtreecommitdiff
path: root/notebooks/00-testing-stuff-out.ipynb
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /notebooks/00-testing-stuff-out.ipynb
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Initial commit for refactoring to lightning
Diffstat (limited to 'notebooks/00-testing-stuff-out.ipynb')
-rw-r--r--notebooks/00-testing-stuff-out.ipynb1469
1 files changed, 1469 insertions, 0 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
new file mode 100644
index 0000000..becd918
--- /dev/null
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -0,0 +1,1469 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch.nn.functional as F\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torchsummary import summary\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import CNN, TDS2d"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tds2d = TDS2d(**{\n",
+ " \"depth\" : 4,\n",
+ " \"tds_groups\" : [\n",
+ " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n",
+ " ],\n",
+ " \"kernel_size\" : [5, 7],\n",
+ " \"dropout_rate\" : 0.1\n",
+ " }, input_dim=32, output_dim=128)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TDS2d(\n",
+ " (tds): Sequential(\n",
+ " (0): Conv2d(1, 16, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (4): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (5): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (6): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (7): Conv2d(16, 128, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n",
+ " (8): ReLU(inplace=True)\n",
+ " (9): Dropout(p=0.1, inplace=False)\n",
+ " (10): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (11): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (12): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (13): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (14): Conv2d(128, 256, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n",
+ " (15): ReLU(inplace=True)\n",
+ " (16): Dropout(p=0.1, inplace=False)\n",
+ " (17): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (18): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (19): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (20): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (21): Conv2d(256, 512, kernel_size=[5, 7], stride=[2, 1], padding=(2, 3))\n",
+ " (22): ReLU(inplace=True)\n",
+ " (23): Dropout(p=0.1, inplace=False)\n",
+ " (24): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (25): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (26): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (27): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (fc): Linear(in_features=1024, out_features=128, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tds2d"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Sequential: 1-1 [-1, 512, 2, 119] --\n",
+ "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n",
+ "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n",
+ "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n",
+ "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n",
+ "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n",
+ "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n",
+ "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n",
+ "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n",
+ "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n",
+ "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n",
+ "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n",
+ "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n",
+ "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n",
+ "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n",
+ "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n",
+ "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n",
+ "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n",
+ "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n",
+ "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n",
+ "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n",
+ "├─Linear: 1-2 [-1, 119, 128] 131,200\n",
+ "===============================================================================================\n",
+ "Total params: 10,272,252\n",
+ "Trainable params: 10,272,252\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 5.00\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 73.21\n",
+ "Params size (MB): 39.19\n",
+ "Estimated Total Size (MB): 112.50\n",
+ "===============================================================================================\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Sequential: 1-1 [-1, 512, 2, 119] --\n",
+ "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n",
+ "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n",
+ "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n",
+ "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n",
+ "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n",
+ "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n",
+ "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n",
+ "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n",
+ "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n",
+ "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n",
+ "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n",
+ "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n",
+ "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n",
+ "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n",
+ "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n",
+ "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n",
+ "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n",
+ "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n",
+ "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n",
+ "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n",
+ "├─Linear: 1-2 [-1, 119, 128] 131,200\n",
+ "===============================================================================================\n",
+ "Total params: 10,272,252\n",
+ "Trainable params: 10,272,252\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 5.00\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 73.21\n",
+ "Params size (MB): 39.19\n",
+ "Estimated Total Size (MB): 112.50\n",
+ "==============================================================================================="
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(2,1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 119, 128])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tds2d(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cnn = CNN().cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "i = nn.Sequential(nn.Conv2d(1,1,1,1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nn.Sequential(i,i)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cnn(t.cuda()).shape  # cnn lives on the GPU (CNN().cuda()), so the CPU tensor t must be moved there too"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(2, 1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x, l = vqvae(t)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "5 * 59 / 10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "up = nn.Upsample([4, 59])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "up(tt).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tt.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class GEGLU(nn.Module):  # GELU-gated linear unit: one projection, split into a value half and a gate half\n",
+ "    def __init__(self, dim_in, dim_out):\n",
+ "        super().__init__()\n",
+ "        self.proj = nn.Linear(dim_in, dim_out * 2)  # 2 * dim_out features: dim_out values + dim_out gates\n",
+ "\n",
+ "    def forward(self, x):\n",
+ "        value, gate = self.proj(x).chunk(2, dim=-1)  # split the projected last dim into two equal halves\n",
+ "        return value * F.gelu(gate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e = GEGLU(256, 2048)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "emb = nn.Embedding(56, 256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():  # plain lookup; no autograd graph is needed\n",
+ "    e = emb(torch.tensor([55], dtype=torch.long))  # int64 index built directly, no float32 round trip"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import repeat"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ee = repeat(e, \"() n -> b n\", b=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "emb.weight.device  # nn.Module has no .device attribute; read the device from the weight parameter"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ee"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ee.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(16, 10, 256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import WideResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wr = WideResidualNetwork(\n",
+ " in_channels= 1,\n",
+ " num_classes= 80,\n",
+ " in_planes=64,\n",
+ " depth=10,\n",
+ " num_layers=4,\n",
+ " width_factor=2,\n",
+ " num_stages=[64, 128, 256, 256],\n",
+ " dropout_rate= 0.1,\n",
+ " activation= \"SELU\",\n",
+ " use_decoder= False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone = nn.Sequential(\n",
+ "    *wr.children()  # all immediate child modules of wr, unpacked as positional arguments\n",
+ ")\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a = torch.rand(1, 1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = wr(a)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import rearrange"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = rearrange(b, \"b c h w -> b w c h\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "c = nn.AdaptiveAvgPool2d((None, 1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d = c(b)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d.squeeze(3).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch import nn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "32 + 64"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "3 * 112"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "col_embed = nn.Parameter(torch.rand(1000, 256 // 2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "W, H = 196, 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ " torch.cat(\n",
+ " [\n",
+ " col_embed[:W].unsqueeze(0).repeat(H, 1, 1),\n",
+ " col_embed[:H].unsqueeze(1).repeat(1, W, 1),\n",
+ " ],\n",
+ " dim=-1,\n",
+ " ).unsqueeze(0).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "4 * 196"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target = torch.tensor([1,1,12,1,1,1,1,1,9,9,9,9,9,9])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.nonzero(target == 9, as_tuple=False)[0].item()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target[:9]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.inf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.transformer.positional_encoding import PositionalEncoding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(15, 5))\n",
+ "pe = PositionalEncoding(20, 0)\n",
+ "y = pe.forward(torch.zeros(1, 100, 20))\n",
+ "plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())\n",
+ "plt.legend([\"dim %d\"%p for p in [4,5,6,7]])\n",
+ "None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.densenet import DenseNet,_DenseLayer,_DenseBlock"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "216 / 8"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone = nn.Sequential(\n",
+ "    *list(dnet.children())[:-4]  # immediate child modules of dnet, minus the last four\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import WideResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Wide residual network instance inspected with summary() in a later cell.\n",
+ "w = WideResidualNetwork(\n",
+ "    in_channels=1,\n",
+ "    in_planes=32,\n",
+ "    num_classes=80,\n",
+ "    depth=10,\n",
+ "    width_factor=1,\n",
+ "    dropout_rate=0.0,\n",
+ "    num_layers=5,\n",
+ "    activation=\"relu\",\n",
+ "    use_decoder=False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(w, (1, 28, 952), device=\"cpu\", depth=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sz= 5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Square (sz, sz) mask: -inf strictly above the main diagonal, 0 on and below it.\n",
+ "mask = torch.full((sz, sz), float('-inf')).triu(diagonal=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "h = torch.rand(1, 256, 10, 10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "h.flatten(2).permute(2, 0, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "h.flatten(2).permute(2, 0, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Toy prediction / target index tensors (int64), written ten values per row.\n",
+ "pred = torch.tensor(\n",
+ "    [\n",
+ "        1, 21, 2, 45, 31, 81, 1, 79, 79, 79,\n",
+ "        2, 1, 1, 1, 1, 81, 1, 79, 79, 79,\n",
+ "        1, 1, 1, 1, 1, 81, 79, 79, 79, 79,\n",
+ "    ],\n",
+ "    dtype=torch.long,\n",
+ ")\n",
+ "target = torch.tensor(\n",
+ "    [\n",
+ "        1, 1, 1, 1, 1, 81, 79, 79, 79, 79,\n",
+ "        1, 1, 1, 1, 1, 81, 79, 79, 79, 79,\n",
+ "        1, 1, 1, 1, 1, 81, 79, 79, 79, 79,\n",
+ "    ],\n",
+ "    dtype=torch.long,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask = (target != 79)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred * mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target * mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.models.metrics import accuracy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pad_indcies = torch.nonzero(target == 79, as_tuple=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target.shape[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t2 = torch.arange(10, target.shape[0] + 1, 10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for start, stop in zip(t1, t2):\n",
+ " pred[start+1:stop] = 79"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Assignment is a statement, so it cannot appear inside a list comprehension\n",
+ "# (SyntaxError). Vectorised equivalent of the for-loop over zip(t1, t2): mark every\n",
+ "# position strictly between a (start, stop) pair and assign once.\n",
+ "# NOTE: relies on broadcasting, so t1 and t2 must have the same length.\n",
+ "positions = torch.arange(pred.shape[0])\n",
+ "starts, stops = t1.unsqueeze(1), t2.unsqueeze(1)\n",
+ "pad_mask = ((positions > starts) & (positions < stops)).any(dim=0)\n",
+ "pred[pad_mask] = 79"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "pad_indcies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# pad_indcies comes from torch.nonzero(..., as_tuple=False) on a 1-D tensor, so it has\n",
+ "# shape (N, 1) and cannot be used as slice bounds (TypeError); a slice with equal\n",
+ "# bounds would select nothing anyway. Index with the flattened positions instead.\n",
+ "pred[pad_indcies.squeeze(1)] = 79"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "accuracy(pred, target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc = (pred == target).sum().float() / target.shape[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}