From 7e8e54e84c63171e748bbf09516fd517e6821ace Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 20 Mar 2021 18:09:06 +0100 Subject: Inital commit for refactoring to lightning --- .../05-sanity-check-multihead-attention.ipynb | 169 --------------------- 1 file changed, 169 deletions(-) delete mode 100644 src/notebooks/05-sanity-check-multihead-attention.ipynb (limited to 'src/notebooks/05-sanity-check-multihead-attention.ipynb') diff --git a/src/notebooks/05-sanity-check-multihead-attention.ipynb b/src/notebooks/05-sanity-check-multihead-attention.ipynb deleted file mode 100644 index 54f0432..0000000 --- a/src/notebooks/05-sanity-check-multihead-attention.ipynb +++ /dev/null @@ -1,169 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "import cv2\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import torch\n", - "from torch import nn\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')\n", - "\n", - "from text_recognizer.networks.transformer.attention import MultiHeadAttention" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "temp_mha = MultiHeadAttention(hidden_dim=512, num_heads=8)\n", - "def print_out(Q, K, V):\n", - " temp_out, temp_attn = temp_mha.scaled_dot_product_attention(Q, K, V)\n", - " print('Attention weights are:', temp_attn.squeeze())\n", - " print('Output is:', temp_out.squeeze())" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "test_K = torch.tensor(\n", - " [[10, 0, 0],\n", - " [ 0,10, 0],\n", - " [ 0, 0,10],\n", - " [ 0, 0,10]]\n", - ").float()[None,None]\n", - "\n", - "test_V = torch.tensor(\n", - " [[ 1,0,0],\n", - " [ 10,0,0],\n", - " [ 100,5,0],\n", - " [1000,6,0]]\n", - ").float()[None,None]\n", - "\n", - "test_Q = torch.tensor(\n", - " [[0, 10, 0]]\n", - ").float()[None,None]\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Attention weights are: tensor([8.4333e-26, 1.0000e+00, 8.4333e-26, 8.4333e-26])\n", - "Output is: tensor([1.0000e+01, 9.2766e-25, 0.0000e+00])\n" - ] - } - ], - "source": [ - "print_out(test_Q, test_K, test_V)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Attends to the second element, as it should!" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Attention weights are: tensor([4.2166e-26, 4.2166e-26, 5.0000e-01, 5.0000e-01])\n", - "Output is: tensor([550.0000, 5.5000, 0.0000])\n" - ] - } - ], - "source": [ - "test_Q = torch.tensor([[0, 0, 10]]).float()[None,None]\n", - "print_out(test_Q, test_K, test_V)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Focuses equally on the third and fourth key." - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Attention weights are: tensor([[4.2166e-26, 4.2166e-26, 5.0000e-01, 5.0000e-01],\n", - " [8.4333e-26, 1.0000e+00, 8.4333e-26, 8.4333e-26],\n", - " [5.0000e-01, 5.0000e-01, 4.2166e-26, 4.2166e-26]])\n", - "Output is: tensor([[5.5000e+02, 5.5000e+00, 0.0000e+00],\n", - " [1.0000e+01, 9.2766e-25, 0.0000e+00],\n", - " [5.5000e+00, 4.6383e-25, 0.0000e+00]])\n" - ] - } - ], - "source": [ - "test_Q = torch.tensor(\n", - " [[0, 0, 10], [0, 10, 0], [10, 10, 0]]\n", - ").float()[None,None]\n", - "print_out(test_Q, test_K, test_V)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.4" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} -- cgit v1.2.3-70-g09d2