{ "cells": [
{ "cell_type": "code", "execution_count": null, "id": "6ce2519f", "metadata": {}, "outputs": [], "source": [ "import os\n", "# CUDA reads CUDA_VISIBLE_DEVICES (plural); an empty value hides every GPU from CUDA.\n", "# It must be set before any CUDA-backed library initialises, i.e. before the project imports below.\n", "os.environ['CUDA_VISIBLE_DEVICES'] = ''\n", "import random\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "\n", "%load_ext autoreload\n", "%autoreload 2\n", "\n", "# Seed the RNGs so the randomly sampled examples shown below are reproducible.\n", "SEED = 42\n", "random.seed(SEED)\n", "np.random.seed(SEED)\n", "\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", "    # text_recognizer is not importable: fall back to the parent directory (presumably the repo root).\n", "    import sys\n", "    sys.path.append('..')\n", "\n", "from text_recognizer.data.iam_paragraphs import IAMParagraphs\n", "from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs\n", "from text_recognizer.data.iam_extended_paragraphs import IAMExtendedParagraphs" ] },
{ "cell_type": "code", "execution_count": null, "id": "726ac25b", "metadata": {}, "outputs": [], "source": [ "def _plot(image, figsize=(12,12), title='', vmin=0, vmax=255):\n", "    \"\"\"Display `image` in a new grayscale figure, mapping intensities onto the fixed range [vmin, vmax].\"\"\"\n", "    plt.figure(figsize=figsize)\n", "    if title:\n", "        plt.title(title)\n", "    plt.imshow(image, cmap='gray', vmin=vmin, vmax=vmax)\n", "\n", "def convert_y_label_to_string(y, mapping, padding_index=3):\n", "    \"\"\"Look up each label index of `y` in `mapping` and join the results, skipping `padding_index`.\n", "\n", "    NOTE(review): padding_index=3 is assumed to be the padding token's index in `mapping` -- confirm.\n", "    \"\"\"\n", "    return ''.join([mapping[i] for i in y if i != padding_index])" ] },
{ "cell_type": "code", "execution_count": null, "id": "c6188bce", "metadata": { "scrolled": true }, "outputs": [], "source": [ "dataset = IAMExtendedParagraphs(batch_size=1, word_pieces=True)\n", "dataset.prepare_data()\n", "dataset.setup()\n", "print(dataset)" ] },
{ "cell_type": "code", "execution_count": null, "id": "55b26b5d", "metadata": {}, "outputs": [], "source": [ "len(dataset.mapping)" ] },
{ "cell_type": "code", "execution_count": null, "id": "42501428", "metadata": {}, "outputs": [], "source": [ "dataset = IAMParagraphs()\n", "dataset.prepare_data()\n", "dataset.setup()\n", "print(dataset)" ] },
{ "cell_type": "code", "execution_count": null, "id": "e6e8c05b", "metadata": {}, "outputs": [], "source": [ "# Keep the dataloader batch under its own names so the single-sample cells below don't overwrite it.\n", "x_batch, y_batch = next(iter(dataset.test_dataloader()))" ] },
{ "cell_type": "code", "execution_count": null, "id": "8bed2170", "metadata": {}, "outputs": [], "source": [ "x_batch.shape" ] },
{ "cell_type": "code", "execution_count": null, "id": "0cf22683", "metadata": {}, "outputs": [], "source": [ "x, y = dataset.data_train[0]" ] },
{ "cell_type": "code", "execution_count": null, "id": "8541e6ee", "metadata": {}, "outputs": [], "source": [ "x.shape" ] },
{ "cell_type": "code", "execution_count": null, "id": "40447ce6", "metadata": {}, "outputs": [], "source": [ "y" ] },
{ "cell_type": "code", "execution_count": null, "id": "016e8c81", "metadata": {}, "outputs": [], "source": [ "len(y)" ] },
{ "cell_type": "code", "execution_count": null, "id": "7aa8c021", "metadata": { "scrolled": true }, "outputs": [], "source": [ "x.shape" ] },
{ "cell_type": "code", "execution_count": null, "id": "7ef93252", "metadata": {}, "outputs": [], "source": [ "_plot(x[0], vmax=1, title=dataset.mapping.get_text(y))" ] },
{ "cell_type": "code", "execution_count": null, "id": "6c62572f", "metadata": {}, "outputs": [], "source": [ "# First sample of the test batch: index [sample, channel] for the image and [sample] for the labels.\n", "_plot(x_batch[0, 0], vmax=1, title=convert_y_label_to_string(y_batch[0], dataset.mapping))" ] },
{ "cell_type": "code", "execution_count": null, "id": "e7778ae2", "metadata": {}, "outputs": [], "source": [ "# Training\n", "\n", "for _ in range(5):\n", "    # randrange excludes the upper bound; randint(0, len) could pick an index one past the end.\n", "    i = random.randrange(len(dataset.data_train))\n", "    x, y = dataset.data_train[i]\n", "    _plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] },
{ "cell_type": "code", "execution_count": null, "id": "dbf845a5", "metadata": {}, "outputs": [], "source": [ "from einops import rearrange" ] },
{ "cell_type": "code", "execution_count": null, "id": "fe4bfb95", "metadata": {}, "outputs": [], "source": [ "x, y = dataset.data_train[2]" ] },
{ "cell_type": "code", "execution_count": null, "id": "a0ba4dec", "metadata": {}, "outputs": [], "source": [ "_plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] },
{ "cell_type": "code", "execution_count": null, "id": "34348d0e", "metadata":
{}, "outputs": [], "source": [ "# Split the image into non-overlapping p x p patches (einops requires height and width to be divisible by p).\n", "p = 32\n", "patches = rearrange(x.unsqueeze(0), 'b c (h p1) (w p2) -> b c (h w) p1 p2', p1 = p, p2 = p)" ] },
{ "cell_type": "code", "execution_count": null, "id": "77bded74", "metadata": {}, "outputs": [], "source": [ "# Show a strip of consecutive patches starting from an arbitrary patch index.\n", "first_patch, n_patches = 160, 15\n", "fig, axes = plt.subplots(1, n_patches, figsize=(20, 20))\n", "for offset, ax in enumerate(axes):\n", "    ax.imshow(patches[0, 0, first_patch + offset], cmap='gray')" ] },
{ "cell_type": "code", "execution_count": null, "id": "9d11ca56", "metadata": {}, "outputs": [], "source": [ "# Testing\n", "\n", "for _ in range(5):\n", "    # randrange excludes the upper bound; randint(0, len) could pick an index one past the end.\n", "    i = random.randrange(len(dataset.data_test))\n", "    x, y = dataset.data_test[i]\n", "    _plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] },
{ "cell_type": "code", "execution_count": null, "id": "548d10da", "metadata": {}, "outputs": [], "source": [ "dataset = IAMSyntheticParagraphs()\n", "dataset.prepare_data()\n", "dataset.setup()\n", "print(dataset)" ] },
{ "cell_type": "code", "execution_count": null, "id": "627730b5", "metadata": {}, "outputs": [], "source": [ "# Training\n", "\n", "for _ in range(5):\n", "    # randrange excludes the upper bound; randint(0, len) could pick an index one past the end.\n", "    i = random.randrange(len(dataset.data_train))\n", "    x, y = dataset.data_train[i]\n", "    _plot(x[0], vmax=1, title=convert_y_label_to_string(y, dataset.mapping))" ] }
], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 5 }