summaryrefslogtreecommitdiff
path: root/src/notebooks/Untitled.ipynb
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-10-22 22:45:58 +0200
commit4d7713746eb936832e84852e90292936b933e87d (patch)
tree2b2519d1d2ce53d4e1390590f52018d55dadbc7c /src/notebooks/Untitled.ipynb
parent1b3b8073a19f939d18a0bb85247eb0d99284f7cc (diff)
Transformer added, many other changes.
Diffstat (limited to 'src/notebooks/Untitled.ipynb')
-rw-r--r--src/notebooks/Untitled.ipynb310
1 files changed, 310 insertions, 0 deletions
diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/Untitled.ipynb
new file mode 100644
index 0000000..76c4d28
--- /dev/null
+++ b/src/notebooks/Untitled.ipynb
@@ -0,0 +1,310 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "import importlib\n",
+ "import cv2\n",
+ "import yaml\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def convert_y_label_to_string(y, dataset=None):\n",
+ "    \"\"\"Map every label in `y` through `dataset.mapper` and join the results.\n",
+ "\n",
+ "    Args:\n",
+ "        y: Iterable of values convertible with `int()`.\n",
+ "        dataset: Object exposing `mapper(int)`. When omitted, the notebook-level\n",
+ "            `dataset` variable is looked up at call time.\n",
+ "\n",
+ "    Returns:\n",
+ "        The mapped items concatenated into a single string.\n",
+ "\n",
+ "    Raises:\n",
+ "        NameError: If `dataset` is omitted and no notebook-level `dataset` exists.\n",
+ "    \"\"\"\n",
+ "    # Resolve the default lazily: a `dataset=dataset` default is evaluated when\n",
+ "    # this cell runs, which fails on a fresh kernel because `dataset` is only\n",
+ "    # created by a later cell.\n",
+ "    if dataset is None:\n",
+ "        try:\n",
+ "            dataset = globals()[\"dataset\"]\n",
+ "        except KeyError:\n",
+ "            raise NameError(\"`dataset` is not defined; pass it explicitly\") from None\n",
+ "    return ''.join(dataset.mapper(int(i)) for i in y)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.models import VisionTransformerModel\n",
+ "from text_recognizer.datasets import IamLinesDataset\n",
+ "from text_recognizer.datasets.transforms import Compose, AddTokens"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 80,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Build the target transform first, then hand it to the dataset constructor.\n",
+ "target_transform = Compose(\n",
+ "    [\n",
+ "        torch.tensor,\n",
+ "        AddTokens(init_token=\"<sos>\", eos_token=\"<eos>\"),\n",
+ "    ]\n",
+ ")\n",
+ "dataset = IamLinesDataset(\n",
+ "    train=True,\n",
+ "    init_token=\"<sos>\",\n",
+ "    pad_token=\"_\",\n",
+ "    eos_token=\"<eos>\",\n",
+ "    target_transform=target_transform,\n",
+ ")\n",
+ "dataset.load_or_generate_data()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 55,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config_path = \"../training/experiments/VisionTransformerModel_IamLinesDataset_VisionTransformer/1021_083538/config.yml\"\n",
+ "# Read the experiment configuration. The encoding is pinned so that parsing does\n",
+ "# not depend on the platform's default locale.\n",
+ "with open(config_path, \"r\", encoding=\"utf-8\") as f:\n",
+ "    experiment_config = yaml.safe_load(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Look up the dataset class by the name stored in the config's \"dataset\" entry.\n",
+ "dataset_args = experiment_config.get(\"dataset\", {})\n",
+ "datasets_module = importlib.import_module(\"text_recognizer.datasets\")\n",
+ "dataset_type_name = dataset_args[\"type\"]\n",
+ "dataset_ = getattr(datasets_module, dataset_type_name)\n",
+ "\n",
+ "# Same lookup for the network, using the name under the \"network\" entry.\n",
+ "network_module = importlib.import_module(\"text_recognizer.networks\")\n",
+ "network_type_name = experiment_config[\"network\"][\"type\"]\n",
+ "network_fn_ = getattr(network_module, network_type_name)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-10-21 23:27:40.719 | DEBUG | text_recognizer.models.base:load_weights:454 - Loading network with pretrained weights.\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = VisionTransformerModel(network_fn=network_fn_, dataset=dataset_, dataset_args=dataset_args)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2020-10-21 23:29:55.892 | DEBUG | text_recognizer.models.base:load_from_checkpoint:402 - Loading checkpoint...\n"
+ ]
+ }
+ ],
+ "source": [
+ "checkpoint_path = \"../training/experiments/VisionTransformerModel_IamLinesDataset_VisionTransformer/1021_083538/model/last.pt\"\n",
+ "model.load_from_checkpoint(checkpoint_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model.eval()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 90,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Pick one sample from the dataset and decode its target into a string.\n",
+ "sample_index = 18\n",
+ "data, target = dataset[sample_index]\n",
+ "sentence = convert_y_label_to_string(target, dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 91,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "([], [])"
+ ]
+ },
+ "execution_count": 91,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "\n",
+ "text/plain": [
+ "<Figure size 1440x1440 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Show the selected sample with its decoded label as the title.\n",
+ "fig, ax = plt.subplots(figsize=(20, 20))\n",
+ "ax.set_title(sentence)\n",
+ "ax.imshow(data.squeeze(0).numpy(), cmap='gray')\n",
+ "ax.set_xticks([])\n",
+ "# The trailing semicolon keeps the return value of the last call out of the\n",
+ "# cell output, so only the figure is displayed.\n",
+ "ax.set_yticks([]);"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 92,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "('Since 1958, 13 Labour life Peers and<eos>', 0.9999997615814209)"
+ ]
+ },
+ "execution_count": 92,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.predict_on_image(data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 95,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[[1, 28, 952], [92]]"
+ ]
+ },
+ "execution_count": 95,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "experiment_config[\"train_args\"][\"input_shape\"]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 99,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "=========================================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "=========================================================================================================\n",
+ "├─Sequential: 1-1 [-1, 158, 1, 28, 6] --\n",
+ "| └─Unfold: 2-1 [-1, 168, 158] --\n",
+ "| └─Rearrange: 2-2 [-1, 158, 1, 28, 6] --\n",
+ "├─Linear: 1-2 [-1, 158, 512] 86,528\n",
+ "├─PositionalEncoding: 1-3 [-1, 158, 512] --\n",
+ "| └─Dropout: 2-3 [-1, 158, 512] --\n",
+ "├─Embedding: 1-4 [-1, 92, 512] 41,984\n",
+ "├─PositionalEncoding: 1-5 [-1, 92, 512] --\n",
+ "| └─Dropout: 2-4 [-1, 92, 512] --\n",
+ "├─Transformer: 1-6 [-1, 92, 512] --\n",
+ "| └─Encoder: 2-5 [-1, 158, 512] --\n",
+ "| | └─ModuleList: 3 [] --\n",
+ "| | | └─EncoderLayer: 4-1 [-1, 158, 512] 3,150,848\n",
+ "| | | └─EncoderLayer: 4-2 [-1, 158, 512] 3,150,848\n",
+ "| | | └─EncoderLayer: 4-3 [-1, 158, 512] 3,150,848\n",
+ "| | | └─EncoderLayer: 4-4 [-1, 158, 512] 3,150,848\n",
+ "| | └─LayerNorm: 3-1 [-1, 158, 512] 1,024\n",
+ "| └─Decoder: 2-6 [-1, 92, 512] --\n",
+ "| | └─ModuleList: 3 [] --\n",
+ "| | | └─DecoderLayer: 4-5 [-1, 92, 512] 4,200,960\n",
+ "| | | └─DecoderLayer: 4-6 [-1, 92, 512] 4,200,960\n",
+ "| | | └─DecoderLayer: 4-7 [-1, 92, 512] 4,200,960\n",
+ "| | | └─DecoderLayer: 4-8 [-1, 92, 512] 4,200,960\n",
+ "| | └─LayerNorm: 3-2 [-1, 92, 512] 1,024\n",
+ "├─Sequential: 1-7 [-1, 92, 82] --\n",
+ "| └─LayerNorm: 2-7 [-1, 92, 512] 1,024\n",
+ "| └─Linear: 2-8 [-1, 92, 512] 262,656\n",
+ "| └─GELU: 2-9 [-1, 92, 512] --\n",
+ "| └─Dropout: 2-10 [-1, 92, 512] --\n",
+ "| └─Linear: 2-11 [-1, 92, 82] 42,066\n",
+ "=========================================================================================================\n",
+ "Total params: 29,843,538\n",
+ "Trainable params: 29,843,538\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (M): 118.22\n",
+ "=========================================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 2.73\n",
+ "Params size (MB): 113.84\n",
+ "Estimated Total Size (MB): 116.68\n",
+ "=========================================================================================================\n"
+ ]
+ }
+ ],
+ "source": [
+ "model.summary(experiment_config[\"train_args\"][\"input_shape\"], 4)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.4"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}