summaryrefslogtreecommitdiff
path: root/src/notebooks/05-sanity-check-multihead-attention.ipynb
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89 /src/notebooks/05-sanity-check-multihead-attention.ipynb
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
Diffstat (limited to 'src/notebooks/05-sanity-check-multihead-attention.ipynb')
-rw-r--r--src/notebooks/05-sanity-check-multihead-attention.ipynb169
1 files changed, 0 insertions, 169 deletions
diff --git a/src/notebooks/05-sanity-check-multihead-attention.ipynb b/src/notebooks/05-sanity-check-multihead-attention.ipynb
deleted file mode 100644
index 54f0432..0000000
--- a/src/notebooks/05-sanity-check-multihead-attention.ipynb
+++ /dev/null
@@ -1,169 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "\n",
- "import cv2\n",
- "%matplotlib inline\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "import torch\n",
- "from torch import nn\n",
- "from importlib.util import find_spec\n",
- "if find_spec(\"text_recognizer\") is None:\n",
- " import sys\n",
- " sys.path.append('..')\n",
- "\n",
- "from text_recognizer.networks.transformer.attention import MultiHeadAttention"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "temp_mha = MultiHeadAttention(hidden_dim=512, num_heads=8)\n",
- "def print_out(Q, K, V):\n",
- " temp_out, temp_attn = temp_mha.scaled_dot_product_attention(Q, K, V)\n",
- " print('Attention weights are:', temp_attn.squeeze())\n",
- " print('Output is:', temp_out.squeeze())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "test_K = torch.tensor(\n",
- " [[10, 0, 0],\n",
- " [ 0,10, 0],\n",
- " [ 0, 0,10],\n",
- " [ 0, 0,10]]\n",
- ").float()[None,None]\n",
- "\n",
- "test_V = torch.tensor(\n",
- " [[ 1,0,0],\n",
- " [ 10,0,0],\n",
- " [ 100,5,0],\n",
- " [1000,6,0]]\n",
- ").float()[None,None]\n",
- "\n",
- "test_Q = torch.tensor(\n",
- " [[0, 10, 0]]\n",
- ").float()[None,None]\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Attention weights are: tensor([8.4333e-26, 1.0000e+00, 8.4333e-26, 8.4333e-26])\n",
- "Output is: tensor([1.0000e+01, 9.2766e-25, 0.0000e+00])\n"
- ]
- }
- ],
- "source": [
- "print_out(test_Q, test_K, test_V)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Attends to the second element, as it should!"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Attention weights are: tensor([4.2166e-26, 4.2166e-26, 5.0000e-01, 5.0000e-01])\n",
- "Output is: tensor([550.0000, 5.5000, 0.0000])\n"
- ]
- }
- ],
- "source": [
- "test_Q = torch.tensor([[0, 0, 10]]).float()[None,None]\n",
- "print_out(test_Q, test_K, test_V)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Focuses equally on the third and fourth key."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Attention weights are: tensor([[4.2166e-26, 4.2166e-26, 5.0000e-01, 5.0000e-01],\n",
- " [8.4333e-26, 1.0000e+00, 8.4333e-26, 8.4333e-26],\n",
- " [5.0000e-01, 5.0000e-01, 4.2166e-26, 4.2166e-26]])\n",
- "Output is: tensor([[5.5000e+02, 5.5000e+00, 0.0000e+00],\n",
- " [1.0000e+01, 9.2766e-25, 0.0000e+00],\n",
- " [5.5000e+00, 4.6383e-25, 0.0000e+00]])\n"
- ]
- }
- ],
- "source": [
- "test_Q = torch.tensor(\n",
- " [[0, 0, 10], [0, 10, 0], [10, 10, 0]]\n",
- ").float()[None,None]\n",
- "print_out(test_Q, test_K, test_V)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.4"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}