summaryrefslogtreecommitdiff
path: root/notebooks/00-scratch-pad.ipynb
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
commit4d1f2cef39688871d2caafce42a09316381a27ae (patch)
tree0f4385969e7df6d7d313cd5910bde9a7475ca027 /notebooks/00-scratch-pad.ipynb
parentf0481decdad9afb52494e9e95996deef843ef233 (diff)
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'notebooks/00-scratch-pad.ipynb')
-rw-r--r--notebooks/00-scratch-pad.ipynb644
1 files changed, 588 insertions, 56 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 2ade2bb..16c6533 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -30,106 +30,244 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
+ "from pathlib import Path"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": false
- },
+ "execution_count": 2,
+ "metadata": {},
"outputs": [],
"source": [
- "en = EfficientNet(\"b0\")"
+ "import attr"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": false
- },
+ "execution_count": 9,
+ "metadata": {},
"outputs": [],
"source": [
- "summary(en, (1, 224, 224));"
+ "@attr.s\n",
+ "class B:\n",
+ " batch_size = attr.ib()\n",
+ " num_workers = attr.ib()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
- "torch.cuda.is_available()"
+ "@attr.s\n",
+ "class T(B):\n",
+ "\n",
+ " def __attrs_post_init__(self) -> None:\n",
+ " super().__init__(self.batch_size, self.num_workers)\n",
+ " self.hej = None\n",
+ " \n",
+ " batch_size = attr.ib()\n",
+ " num_workers = attr.ib()\n",
+ " h: Path = attr.ib(converter=Path)"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
- "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
+ "t = T(batch_size=16, num_workers=2, h=\"hej\")"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PosixPath('hej')"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "decoder.cuda()"
+ "t.h"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "16"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "transformer_decoder = Transformer(num_tokens=1000, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"
+ "t.batch_size"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PosixPath('hej')"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "transformer_decoder.cuda()"
+ "t.h"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "../text_recognizer/__init__.py\n",
+ "../text_recognizer/callbacks/__init__.py\n",
+ "../text_recognizer/callbacks/wandb_callbacks.py\n",
+ "../text_recognizer/data/image_utils.py\n",
+ "../text_recognizer/data/emnist.py\n",
+ "../text_recognizer/data/iam_lines.py\n",
+ "../text_recognizer/data/download_utils.py\n",
+ "../text_recognizer/data/mappings.py\n",
+ "../text_recognizer/data/iam_preprocessor.py\n",
+ "../text_recognizer/data/__init__.py\n",
+ "../text_recognizer/data/make_wordpieces.py\n",
+ "../text_recognizer/data/iam_paragraphs.py\n",
+ "../text_recognizer/data/sentence_generator.py\n",
+ "../text_recognizer/data/emnist_lines.py\n",
+ "../text_recognizer/data/build_transitions.py\n",
+ "../text_recognizer/data/base_dataset.py\n",
+ "../text_recognizer/data/base_data_module.py\n",
+ "../text_recognizer/data/iam.py\n",
+ "../text_recognizer/data/iam_synthetic_paragraphs.py\n",
+ "../text_recognizer/data/transforms.py\n",
+ "../text_recognizer/data/iam_extended_paragraphs.py\n",
+ "../text_recognizer/networks/__init__.py\n",
+ "../text_recognizer/networks/util.py\n",
+ "../text_recognizer/networks/cnn_tranformer.py\n",
+ "../text_recognizer/networks/encoders/__init__.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/efficientnet.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/__init__.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/utils.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/mbconv.py\n",
+ "../text_recognizer/networks/loss/__init__.py\n",
+ "../text_recognizer/networks/loss/label_smoothing_loss.py\n",
+ "../text_recognizer/networks/vqvae/__init__.py\n",
+ "../text_recognizer/networks/vqvae/decoder.py\n",
+ "../text_recognizer/networks/vqvae/vqvae.py\n",
+ "../text_recognizer/networks/vqvae/vector_quantizer.py\n",
+ "../text_recognizer/networks/vqvae/encoder.py\n",
+ "../text_recognizer/networks/transformer/__init__.py\n",
+ "../text_recognizer/networks/transformer/layers.py\n",
+ "../text_recognizer/networks/transformer/residual.py\n",
+ "../text_recognizer/networks/transformer/attention.py\n",
+ "../text_recognizer/networks/transformer/transformer.py\n",
+ "../text_recognizer/networks/transformer/vit.py\n",
+ "../text_recognizer/networks/transformer/mlp.py\n",
+ "../text_recognizer/networks/transformer/norm.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/positional_encoding.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/__init__.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py\n",
+ "../text_recognizer/networks/transformer/nystromer/__init__.py\n",
+ "../text_recognizer/networks/transformer/nystromer/nystromer.py\n",
+ "../text_recognizer/networks/transformer/nystromer/attention.py\n",
+ "../text_recognizer/models/__init__.py\n",
+ "../text_recognizer/models/base.py\n",
+ "../text_recognizer/models/vqvae.py\n",
+ "../text_recognizer/models/transformer.py\n",
+ "../text_recognizer/models/dino.py\n",
+ "../text_recognizer/models/metrics.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "for f in Path(\"../text_recognizer\").glob(\"**/*.py\"):\n",
+ " print(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<generator object Path.glob at 0x7ff8bb9ce5f0>"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Path(\"..\").glob(\"**/*.py\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "efficient_transformer = Nystromer(\n",
- " dim = 64,\n",
- " depth = 4,\n",
- " num_heads = 8,\n",
- " num_landmarks = 64\n",
- ")"
+ "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "en = EfficientNet(\"b0\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [],
"source": [
- "v = ViT(\n",
- " dim = 64,\n",
- " image_size = (576, 640),\n",
- " patch_size = (32, 32),\n",
- " transformer = efficient_transformer\n",
- ").cuda()"
+ "summary(en, (1, 224, 224));"
]
},
{
@@ -138,7 +276,7 @@
"metadata": {},
"outputs": [],
"source": [
- "t = torch.randn(8, 1, 576, 640).cuda()"
+ "torch.cuda.is_available()"
]
},
{
@@ -147,7 +285,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en.cuda()"
+ "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
]
},
{
@@ -156,7 +294,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en(t)"
+ "decoder.cuda()"
]
},
{
@@ -165,7 +303,7 @@
"metadata": {},
"outputs": [],
"source": [
- "o = v(t)"
+ "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"
]
},
{
@@ -174,7 +312,7 @@
"metadata": {},
"outputs": [],
"source": [
- "caption = torch.randint(0, 90, (16, 690)).cuda()"
+ "transformer_decoder.cuda()"
]
},
{
@@ -183,7 +321,12 @@
"metadata": {},
"outputs": [],
"source": [
- "o.shape"
+ "efficient_transformer = Nystromer(\n",
+ " dim = 64,\n",
+ " depth = 4,\n",
+ " num_heads = 8,\n",
+ " num_landmarks = 64\n",
+ ")"
]
},
{
@@ -192,16 +335,405 @@
"metadata": {},
"outputs": [],
"source": [
- "caption.shape"
+ "v = ViT(\n",
+ " dim = 64,\n",
+ " image_size = (576, 640),\n",
+ " patch_size = (32, 32),\n",
+ " transformer = efficient_transformer\n",
+ ").cuda()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
- "o = torch.randn(16, 20 * 18, 128).cuda()"
+ "t = torch.randn(4, 1, 576, 640).cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "EfficientNet(\n",
+ " (_conv_stem): Sequential(\n",
+ " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n",
+ " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
+ " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (3): Mish(inplace=True)\n",
+ " )\n",
+ " (_blocks): ModuleList(\n",
+ " (0): MBConvBlock(\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n",
+ " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)\n",
+ " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (2): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (3): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (4): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (5): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (6): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (7): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (8): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (9): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (10): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (11): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (12): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (13): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (14): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (15): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_conv_head): Sequential(\n",
+ " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "en.cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 1280, 18, 20])"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "en(t).shape"
]
},
{
@@ -210,7 +742,7 @@
"metadata": {},
"outputs": [],
"source": [
- "caption = torch.randint(0, 1000, (16, 200)).cuda()"
+ "o = v(t)"
]
},
{
@@ -219,7 +751,7 @@
"metadata": {},
"outputs": [],
"source": [
- "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)"
+ "caption = torch.randint(0, 90, (16, 690)).cuda()"
]
},
{
@@ -228,7 +760,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
+ "o.shape"
]
},
{
@@ -237,7 +769,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en = EfficientNet()"
+ "caption.shape"
]
},
{
@@ -246,7 +778,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en.cuda()"
+ "o = torch.randn(16, 20 * 18, 128).cuda()"
]
},
{
@@ -255,7 +787,7 @@
"metadata": {},
"outputs": [],
"source": [
- "summary(en, (1, 576, 640))"
+ "caption = torch.randint(0, 1000, (16, 200)).cuda()"
]
},
{
@@ -264,7 +796,7 @@
"metadata": {},
"outputs": [],
"source": [
- "type(efficient_transformer)"
+ "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)"
]
},
{