summaryrefslogtreecommitdiff
path: root/notebooks/05c-test-model-end-to-end.ipynb
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 14:19:37 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-06 14:19:37 +0200
commit263f2b7158d76bc0adad45309625910c0fa7b1fe (patch)
tree6af1782e39812c8d7ff8a853195adc32f67f56c2 /notebooks/05c-test-model-end-to-end.ipynb
parent3ab82ad36bce6fa698a13a029a0694b75a5947b7 (diff)
Remove lr args from model, add Cosine lr, fix to vqvae stack
Diffstat (limited to 'notebooks/05c-test-model-end-to-end.ipynb')
-rw-r--r--notebooks/05c-test-model-end-to-end.ipynb119
1 files changed, 110 insertions, 9 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index 7996257..b26a1fe 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -2,10 +2,19 @@
"cells": [
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"id": "1e40a88b",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -25,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "38fb3d9d-a163-4b72-981f-f31b51be39f2",
"metadata": {},
"outputs": [],
@@ -37,10 +46,46 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"id": "74780b21-3313-452b-b580-703cac878416",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "encoder:\n",
+ " _target_: text_recognizer.networks.vqvae.encoder.Encoder\n",
+ " in_channels: 1\n",
+ " hidden_dim: 32\n",
+ " channels_multipliers:\n",
+ " - 1\n",
+ " - 2\n",
+ " - 4\n",
+ " - 4\n",
+ " - 4\n",
+ " dropout_rate: 0.25\n",
+ "decoder:\n",
+ " _target_: text_recognizer.networks.vqvae.decoder.Decoder\n",
+ " out_channels: 1\n",
+ " hidden_dim: 32\n",
+ " channels_multipliers:\n",
+ " - 4\n",
+ " - 4\n",
+ " - 4\n",
+ " - 2\n",
+ " - 1\n",
+ " dropout_rate: 0.25\n",
+ "_target_: text_recognizer.networks.vqvae.vqvae.VQVAE\n",
+ "hidden_dim: 128\n",
+ "embedding_dim: 32\n",
+ "num_embeddings: 1024\n",
+ "decay: 0.99\n",
+ "\n",
+ "{'encoder': {'_target_': 'text_recognizer.networks.vqvae.encoder.Encoder', 'in_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [1, 2, 4, 4, 4], 'dropout_rate': 0.25}, 'decoder': {'_target_': 'text_recognizer.networks.vqvae.decoder.Decoder', 'out_channels': 1, 'hidden_dim': 32, 'channels_multipliers': [4, 4, 4, 2, 1], 'dropout_rate': 0.25}, '_target_': 'text_recognizer.networks.vqvae.vqvae.VQVAE', 'hidden_dim': 128, 'embedding_dim': 32, 'num_embeddings': 1024, 'decay': 0.99}\n"
+ ]
+ }
+ ],
"source": [
"# context initialization\n",
"with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
@@ -51,7 +96,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"id": "205a03e8-7aa1-407f-afa5-92693715b677",
"metadata": {},
"outputs": [],
@@ -61,7 +106,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"id": "c74384f0-754e-4c29-8f06-339372d6e4c1",
"metadata": {},
"outputs": [],
@@ -71,10 +116,66 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"id": "5ebab599-2497-42f8-b54b-1663ee66fde9",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "==========================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "==========================================================================================\n",
+ "├─Encoder: 1-1 [-1, 128, 18, 20] --\n",
+ "| └─Sequential: 2-1 [-1, 128, 18, 20] --\n",
+ "| | └─Conv2d: 3-1 [-1, 32, 576, 640] 320\n",
+ "| | └─Conv2d: 3-2 [-1, 32, 288, 320] 16,416\n",
+ "| | └─Mish: 3-3 [-1, 32, 288, 320] --\n",
+ "| | └─Conv2d: 3-4 [-1, 64, 144, 160] 32,832\n",
+ "| | └─Mish: 3-5 [-1, 64, 144, 160] --\n",
+ "| | └─Conv2d: 3-6 [-1, 128, 72, 80] 131,200\n",
+ "| | └─Mish: 3-7 [-1, 128, 72, 80] --\n",
+ "| | └─Conv2d: 3-8 [-1, 128, 36, 40] 262,272\n",
+ "| | └─Mish: 3-9 [-1, 128, 36, 40] --\n",
+ "| | └─Conv2d: 3-10 [-1, 128, 18, 20] 262,272\n",
+ "| | └─Mish: 3-11 [-1, 128, 18, 20] --\n",
+ "| | └─Residual: 3-12 [-1, 128, 18, 20] 164,352\n",
+ "| | └─Residual: 3-13 [-1, 128, 18, 20] 164,352\n",
+ "├─Conv2d: 1-2 [-1, 32, 18, 20] 4,128\n",
+ "├─VectorQuantizer: 1-3 [-1, 32, 18, 20] --\n",
+ "├─Conv2d: 1-4 [-1, 128, 18, 20] 4,224\n",
+ "├─Decoder: 1-5 [-1, 1, 576, 640] --\n",
+ "| └─Sequential: 2-2 [-1, 1, 576, 640] --\n",
+ "| | └─Residual: 3-14 [-1, 128, 18, 20] 164,352\n",
+ "| | └─Residual: 3-15 [-1, 128, 18, 20] 164,352\n",
+ "| | └─ConvTranspose2d: 3-16 [-1, 128, 36, 40] 262,272\n",
+ "| | └─Mish: 3-17 [-1, 128, 36, 40] --\n",
+ "| | └─ConvTranspose2d: 3-18 [-1, 128, 72, 80] 262,272\n",
+ "| | └─Mish: 3-19 [-1, 128, 72, 80] --\n",
+ "| | └─ConvTranspose2d: 3-20 [-1, 64, 144, 160] 131,136\n",
+ "| | └─Mish: 3-21 [-1, 64, 144, 160] --\n",
+ "| | └─ConvTranspose2d: 3-22 [-1, 32, 288, 320] 32,800\n",
+ "| | └─Mish: 3-23 [-1, 32, 288, 320] --\n",
+ "| | └─ConvTranspose2d: 3-24 [-1, 32, 576, 640] 16,416\n",
+ "| | └─Mish: 3-25 [-1, 32, 576, 640] --\n",
+ "| | └─Normalize: 3-26 [-1, 32, 576, 640] 64\n",
+ "| | └─Mish: 3-27 [-1, 32, 576, 640] --\n",
+ "| | └─Conv2d: 3-28 [-1, 1, 576, 640] 289\n",
+ "==========================================================================================\n",
+ "Total params: 2,076,321\n",
+ "Trainable params: 2,076,321\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 17.68\n",
+ "==========================================================================================\n",
+ "Input size (MB): 1.41\n",
+ "Forward/backward pass size (MB): 355.17\n",
+ "Params size (MB): 7.92\n",
+ "Estimated Total Size (MB): 364.49\n",
+ "==========================================================================================\n"
+ ]
+ }
+ ],
"source": [
"summary(net, (1, 576, 640), device=\"cpu\");"
]