From 737000da5b44276512beffc1bdf81057df43ab2c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 2 May 2021 22:27:42 +0200 Subject: Attention layer finished --- notebooks/00-testing-stuff-out.ipynb | 163 +++++++++++++++++++++++++++++++++++ 1 file changed, 163 insertions(+) (limited to 'notebooks') diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index 7c7b3a6..92faaf7 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -352,6 +352,169 @@ "t(datum, trg).shape" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "b, n = 16, 128\n", + "device = \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "x = lambda: torch.ones((b, n), device=device).bool()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 128])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 128])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.ones((b, n), device=device).bool().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 1, 576, 640)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "144" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "576 // 4" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "160" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "640 // 4" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 1, 144, 160)" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import rearrange" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "patch_size=4\n", + "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1440, 16])" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p.shape" + ] + }, { "cell_type": "code", "execution_count": null, -- cgit v1.2.3-70-g09d2