From 57525da0f267300792cd6b65e59914644a2dd39b Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 13 Jun 2021 23:58:56 +0200 Subject: Working on new efficient net impl --- notebooks/00-scratch-pad.ipynb | 326 +++++++---------------------------------- 1 file changed, 57 insertions(+), 269 deletions(-) (limited to 'notebooks/00-scratch-pad.ipynb') diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index 4681360..4df7aa1 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -30,249 +30,52 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "True" - ] - }, - "execution_count": 2, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "torch.cuda.is_available()" ] }, { "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "decoder = Decoder(dim=64, depth=2, num_heads=4, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Decoder(\n", - " (layers): ModuleList(\n", - " (0): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (1): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (2): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=64, out_features=512, bias=True)\n", - " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=256, out_features=64, bias=True)\n", - " )\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (3): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (4): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (5): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=64, out_features=512, bias=True)\n", - " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=256, out_features=64, bias=True)\n", - " )\n", - " )\n", - " (2): Residual()\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "decoder.cuda()" ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=64, emb_dropout=0.1)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Transformer(\n", - " (attn_layers): Decoder(\n", - " (layers): ModuleList(\n", - " (0): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (1): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (2): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=64, out_features=512, bias=True)\n", - " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=256, out_features=64, bias=True)\n", - " )\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (3): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (4): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): Attention(\n", - " (qkv_fn): Sequential(\n", - " (0): Linear(in_features=64, out_features=12288, bias=False)\n", - " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n", - " )\n", - " (dropout): Dropout(p=0.0, inplace=False)\n", - " (fc): Linear(in_features=4096, out_features=64, bias=True)\n", - " )\n", - " (2): Residual()\n", - " )\n", - " (5): ModuleList(\n", - " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (1): FeedForward(\n", - " (mlp): Sequential(\n", - " (0): GEGLU(\n", - " (fc): Linear(in_features=64, out_features=512, bias=True)\n", - " )\n", - " (1): Dropout(p=0.0, inplace=False)\n", - " (2): Linear(in_features=256, out_features=64, bias=True)\n", - " )\n", - " )\n", - " (2): Residual()\n", - " )\n", - " )\n", - " )\n", - " (token_emb): Embedding(90, 64)\n", - " (emb_dropout): Dropout(p=0.1, inplace=False)\n", - " (pos_emb): AbsolutePositionalEmbedding(\n", - " (emb): Embedding(690, 64)\n", - " )\n", - " (project_emb): Identity()\n", - " (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n", - " (logits): Linear(in_features=64, out_features=90, bias=True)\n", - ")" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "transformer_decoder = Transformer(num_tokens=1000, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "transformer_decoder.cuda()" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -286,21 +89,21 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "v = ViT(\n", " dim = 64,\n", " image_size = (576, 640),\n", - " patch_size = (64, 64),\n", + " patch_size = (32, 32),\n", " transformer = efficient_transformer\n", ").cuda()" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -309,7 +112,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -318,7 +121,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -327,60 +130,45 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([4, 90, 64])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "o.shape" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([16, 690])" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "caption.shape" ] }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "o = torch.randn(16, 20 * 18, 128).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([16, 690, 90])" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], + "source": [ + "caption = torch.randint(0, 1000, (16, 200)).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], "source": [ "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)" ] -- cgit v1.2.3-70-g09d2