{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "6ce2519f", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n", "import random\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')\n", "\n", "from text_recognizer.data.iam_paragraphs import IAMParagraphs\n", "from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs\n", "from text_recognizer.data.iam_extended_paragraphs import IAMExtendedParagraphs" ] }, { "cell_type": "code", "execution_count": null, "id": "726ac25b", "metadata": {}, "outputs": [], "source": [ "def _plot(image, figsize=(12,12), title='', vmin=0, vmax=255):\n", " plt.figure(figsize=figsize)\n", " if title:\n", " plt.title(title)\n", " plt.imshow(image, cmap='gray', vmin=vmin, vmax=vmax)\n", "\n", "def convert_y_label_to_string(y, mapping, padding_index=3):\n", " return ''.join([mapping[int(i)] for i in y if i != padding_index])" ] }, { "cell_type": "code", "execution_count": null, "id": "ec16e41f-3d12-4da2-bf02-7429b41cf98e", "metadata": {}, "outputs": [], "source": [ "from hydra import compose, initialize\n", "from omegaconf import OmegaConf\n", "from hydra.utils import instantiate" ] }, { "cell_type": "code", "execution_count": null, "id": "3ffde4df-2c15-4f6f-ab24-b09e8e3a20c4", "metadata": {}, "outputs": [], "source": [ "import torch" ] }, { "cell_type": "code", "execution_count": null, "id": "e9386367-2b49-4633-9936-57081132e59e", "metadata": {}, "outputs": [], "source": [ "# context initialization\n", "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n", " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqgan_htr_char\"])\n", " print(OmegaConf.to_yaml(cfg))" ] }, { "cell_type": "code", "execution_count": null, "id": "1c4624d1-6de5-41ab-9208-0988fcdba76d", "metadata": {}, "outputs": [], "source": [ "datamodule = instantiate(cfg.datamodule, mapping=cfg.mapping)\n", "datamodule.prepare_data()\n", "datamodule.setup()\n", "print(datamodule)" ] }, { "cell_type": "code", "execution_count": null, "id": "636b38d0-0fa1-4fc7-9737-6bb08b5b7a67", "metadata": {}, "outputs": [], "source": [ "net = instantiate(cfg.network).cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "82c4f950-d5f4-411d-bf57-7fc044e22e85", "metadata": {}, "outputs": [], "source": [ "x = torch.randn(2, 1, 576, 640).cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "9fa534e8-7934-49e5-be4d-975c4d2c60d9", "metadata": {}, "outputs": [], "source": [ "c = torch.randint(0, 53, (2, 682)).cuda()" ] }, { "cell_type": "code", "execution_count": null, "id": "e06f47b5-37db-427c-aa85-195e20989a32", "metadata": {}, "outputs": [], "source": [ "c.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "28af7b30-74f0-4dc5-b618-e37bd2f3965f", "metadata": {}, "outputs": [], "source": [ "net(x, c)" ] }, { "cell_type": "code", "execution_count": null, "id": "770f29f6-94f3-40c7-80f0-d85bd2d23fef", "metadata": {}, "outputs": [], "source": [ "len(datamodule.train_dataloader())" ] }, { "cell_type": "code", "execution_count": null, "id": "564d96b2-bf28-4a79-b350-261a934ea3d5", "metadata": {}, "outputs": [], "source": [ "x.min()" ] }, { "cell_type": "code", "execution_count": null, "id": "e6e8c05b", "metadata": {}, "outputs": [], "source": [ "x, y = next(iter(datamodule.train_dataloader()))" ] }, { "cell_type": "code", "execution_count": null, "id": "7a225225-f9bd-4643-b46d-910f42e79cce", "metadata": {}, "outputs": [], "source": [ "x.min()" ] }, { "cell_type": "code", "execution_count": null, "id": "4537b1ab-3465-4d19-8cd4-878328beb620", "metadata": {}, "outputs": [], "source": [ "x.max()" ] }, { "cell_type": "code", "execution_count": null, "id": "3caaf16a-9c3b-4c06-a1f0-2b9c9bfa3ae0", "metadata": {}, "outputs": [], "source": [ "from torch import nn" ] }, { "cell_type": "code", "execution_count": null, "id": "ded0a413-c0a8-430a-b903-11ffc18d5e08", "metadata": {}, "outputs": [], "source": [ "loss = nn.BCEWithLogitsLoss()" ] }, { "cell_type": "code", "execution_count": null, "id": "d0d8a42c-256a-48fa-96c6-b6225fba5057", "metadata": {}, "outputs": [], "source": [ "target" ] }, { "cell_type": "code", "execution_count": null, "id": "d427d975-9487-412b-9c43-67a32ae316e7", "metadata": {}, "outputs": [], "source": [ "input = 10 * torch.rand((6, 1, 576, 640), requires_grad=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "cfe20821-3830-4679-b7dd-c2ba509c839b", "metadata": {}, "outputs": [], "source": [ "s = nn.Softmax2d()" ] }, { "cell_type": "code", "execution_count": null, "id": "029de5b3-0de9-43b2-9be2-d98142e15057", "metadata": {}, "outputs": [], "source": [ "input.flatten(-2, -1).shape" ] }, { "cell_type": "code", "execution_count": null, "id": "3126984c-05c0-4766-812b-27b9b719e357", "metadata": {}, "outputs": [], "source": [ "s(input)" ] }, { "cell_type": "code", "execution_count": null, "id": "9b2b1837-f34a-474c-b74b-6e0f86127717", "metadata": {}, "outputs": [], "source": [ "input = torch.randn((8, 1, 576, 640), requires_grad=True)\n", "target = torch.empty((8, 1, 576, 640)).random_(2)\n", "output = loss(input, target)\n", "output.backward()" ] }, { "cell_type": "code", "execution_count": null, "id": "e0da6ebd-ff8a-41d6-9ea0-962f3487f71a", "metadata": {}, "outputs": [], "source": [ "output = loss(input.flatten(-2, -1), target.flatten(-2, -1))" ] }, { "cell_type": "code", "execution_count": null, "id": "c6be9d04-39f0-4344-afd7-a8ef38f4bce2", "metadata": {}, "outputs": [], "source": [ "output = loss(x, target)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "cb41aaf8-b5b8-4d7a-b1eb-3d6dd7eed7a1", "metadata": {}, "outputs": [], "source": [ "output" ] }, { "cell_type": "code", "execution_count": null, "id": "02b26e57-18ff-4a4e-8ebc-e18f8ba87e93", "metadata": {}, "outputs": [], "source": [ "output" ] }, { "cell_type": "code", "execution_count": null, "id": "8bed2170", "metadata": {}, "outputs": [], "source": [ "x.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "0cf22683", "metadata": {}, "outputs": [], "source": [ "x, y = datamodule.data_train[3]" ] }, { "cell_type": "code", "execution_count": null, "id": "074b269f-caff-4ec6-acdc-3f73721d5a05", "metadata": {}, "outputs": [], "source": [ "y" ] }, { "cell_type": "code", "execution_count": null, "id": "1e657891-45bb-479e-95ba-bdefe3a84ae9", "metadata": {}, "outputs": [], "source": [ "convert_y_label_to_string(y, datamodule.mapping, padding_index=3)" ] }, { "cell_type": "code", "execution_count": null, "id": "7aa8c021", "metadata": { "scrolled": true }, "outputs": [], "source": [ "x.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "7ef93252", "metadata": {}, "outputs": [], "source": [ "_plot(x[0], vmax=1, title=datamodule.mapping.get_text(y))" ] }, { "cell_type": "code", "execution_count": null, "id": "7d9119fc-8c8f-4697-bbdf-a982df34eba5", "metadata": {}, "outputs": [], "source": [ "x[0].max()" ] }, { "cell_type": "code", "execution_count": null, "id": "2986088b", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(8,8))\n", "plt.imshow(x[0], cmap='gray')" ] }, { "cell_type": "code", "execution_count": null, "id": "3480ae5f-9cec-4814-98fe-02082a139add", "metadata": {}, "outputs": [], "source": [ "y[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "6c62572f", "metadata": {}, "outputs": [], "source": [ "_plot(x[0, 0], vmax=1, title=convert_y_label_to_string(y[0], datamodule.mapping))" ] }, { "cell_type": "code", "execution_count": null, "id": "e7778ae2", "metadata": {}, "outputs": [], "source": [ "# Training\n", "\n", "for _ in range(5):\n", " i = random.randint(0, len(dataset.data_train))\n", " x, y = dataset.data_train[i]\n", " _plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] }, { "cell_type": "code", "execution_count": null, "id": "dbf845a5", "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] }, { "cell_type": "code", "execution_count": null, "id": "fe4bfb95", "metadata": {}, "outputs": [], "source": [ "x, y = dataset.data_train[2]" ] }, { "cell_type": "code", "execution_count": null, "id": "a0ba4dec", "metadata": {}, "outputs": [], "source": [ "_plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] }, { "cell_type": "code", "execution_count": null, "id": "34348d0e", "metadata": {}, "outputs": [], "source": [ "p = 32\n", "patches = rearrange(x.unsqueeze(0), 'b c (h p1) (w p2) -> b c (h w) p1 p2', p1 = p, p2 = p)" ] }, { "cell_type": "code", "execution_count": null, "id": "77bded74", "metadata": {}, "outputs": [], "source": [ "fig = plt.figure(figsize=(20, 20))\n", "for i in range(15):\n", " ax = fig.add_subplot(1, 15, i + 1)\n", " ax.imshow(patches[0, 0, i + 160, :, :].squeeze(0), cmap='gray')" ] }, { "cell_type": "code", "execution_count": null, "id": "9d11ca56", "metadata": {}, "outputs": [], "source": [ "# Testing\n", "\n", "for _ in range(5):\n", " i = random.randint(0, len(dataset.data_test))\n", " x, y = dataset.data_test[i]\n", " _plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] }, { "cell_type": "code", "execution_count": null, "id": "548d10da", "metadata": {}, "outputs": [], "source": [ "dataset = IAMSyntheticParagraphs()\n", "dataset.prepare_data()\n", "dataset.setup()\n", "print(dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "627730b5", "metadata": {}, "outputs": [], "source": [ "# Training\n", "\n", "for _ in range(5):\n", " i = random.randint(0, len(dataset.data_train))\n", " x, y = dataset.data_train[i]\n", " _plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] }, { "cell_type": "code", "execution_count": null, "id": "4150722e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.7" } }, "nbformat": 4, "nbformat_minor": 5 }