Diffstat (limited to 'notebooks/00-scratch-pad.ipynb')
-rw-r--r--  notebooks/00-scratch-pad.ipynb  326
1 file changed, 57 insertions(+), 269 deletions(-)
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 4681360..4df7aa1 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -30,249 +30,52 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "True"
- ]
- },
- "execution_count": 2,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"torch.cuda.is_available()"
]
},
{
"cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "decoder = Decoder(dim=64, depth=2, num_heads=4, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Decoder(\n",
- " (layers): ModuleList(\n",
- " (0): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (1): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (2): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=256, out_features=64, bias=True)\n",
- " )\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (3): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (4): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (5): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=256, out_features=64, bias=True)\n",
- " )\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"decoder.cuda()"
]
},
{
"cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "transformer_decoder = Transformer(num_tokens=90, max_seq_len=690, attn_layers=decoder, emb_dim=64, emb_dropout=0.1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Transformer(\n",
- " (attn_layers): Decoder(\n",
- " (layers): ModuleList(\n",
- " (0): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (1): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (2): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=256, out_features=64, bias=True)\n",
- " )\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (3): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (4): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): Attention(\n",
- " (qkv_fn): Sequential(\n",
- " (0): Linear(in_features=64, out_features=12288, bias=False)\n",
- " (1): Rearrange('b n (qkv h d) -> qkv b h n d', qkv=3, h=4)\n",
- " )\n",
- " (dropout): Dropout(p=0.0, inplace=False)\n",
- " (fc): Linear(in_features=4096, out_features=64, bias=True)\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " (5): ModuleList(\n",
- " (0): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (1): FeedForward(\n",
- " (mlp): Sequential(\n",
- " (0): GEGLU(\n",
- " (fc): Linear(in_features=64, out_features=512, bias=True)\n",
- " )\n",
- " (1): Dropout(p=0.0, inplace=False)\n",
- " (2): Linear(in_features=256, out_features=64, bias=True)\n",
- " )\n",
- " )\n",
- " (2): Residual()\n",
- " )\n",
- " )\n",
- " )\n",
- " (token_emb): Embedding(90, 64)\n",
- " (emb_dropout): Dropout(p=0.1, inplace=False)\n",
- " (pos_emb): AbsolutePositionalEmbedding(\n",
- " (emb): Embedding(690, 64)\n",
- " )\n",
- " (project_emb): Identity()\n",
- " (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)\n",
- " (logits): Linear(in_features=64, out_features=90, bias=True)\n",
- ")"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "transformer_decoder = Transformer(num_tokens=1000, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"transformer_decoder.cuda()"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -286,21 +89,21 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"v = ViT(\n",
" dim = 64,\n",
" image_size = (576, 640),\n",
- " patch_size = (64, 64),\n",
+ " patch_size = (32, 32),\n",
" transformer = efficient_transformer\n",
").cuda()"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -309,7 +112,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -318,7 +121,7 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -327,60 +130,45 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([4, 90, 64])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"o.shape"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([16, 690])"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"caption.shape"
]
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "o = torch.randn(16, 20 * 18, 128).cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([16, 690, 90])"
- ]
- },
- "execution_count": 14,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
+ "source": [
+ "caption = torch.randint(0, 1000, (16, 200)).cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
"source": [
"transformer_decoder(caption, context = o).shape # (1, 1024, 20000)"
]
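
For reference, below is a minimal self-contained sketch of the pipeline the updated cells wire together. The notebook's Decoder, Transformer, and ViT classes come from a project-local module whose imports sit outside this hunk, so torch's built-in nn.TransformerDecoder stands in for them here; that substitution is an assumption, not the notebook's actual model, but the hyperparameters and tensor shapes mirror the new cells (dim=128, 8 heads, depth 2, vocab 1000, 20*18 = 360 patch tokens from a 576x640 image at patch size 32, 200-token captions).

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in for the notebook's Decoder/Transformer (assumed equivalent wiring):
# a 2-layer cross-attending decoder over 128-dim tokens with 8 heads.
layer = nn.TransformerDecoderLayer(d_model=128, nhead=8, batch_first=True)
decoder = nn.TransformerDecoder(layer, num_layers=2).to(device)
token_emb = nn.Embedding(1000, 128).to(device)  # num_tokens=1000, emb_dim=128
to_logits = nn.Linear(128, 1000).to(device)     # project back to the vocabulary

# Dummy inputs matching the new cells: (576/32) * (640/32) = 18 * 20 = 360
# patch tokens of dim 128 from the ViT encoder, and a batch of 200-token captions.
context = torch.randn(16, 20 * 18, 128, device=device)
caption = torch.randint(0, 1000, (16, 200), device=device)

hidden = decoder(token_emb(caption), memory=context)  # cross-attend to the patches
logits = to_logits(hidden)
print(logits.shape)  # torch.Size([16, 200, 1000])

Note that the trailing comment on the last cell, "# (1, 1024, 20000)", predates this change: with the new shapes, transformer_decoder(caption, context = o) should return logits of shape (16, 200, 1000).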