{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", " sys.path.append('..')\n", "\n", "from text_recognizer.data.emnist import EMNIST" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "EMNIST Dataset\n", "Num classes: 83\n", "Mapping: ['', '', '', '

', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '!', '\"', '#', '&', \"'\", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '?']\n", "Dims: (1, 28, 28)\n", "Train/val/test sizes: 260276, 65070, 54028\n", "Batch x stats: (torch.Size([128, 1, 28, 28]), torch.float32, tensor(0.), tensor(0.1673), tensor(0.3277), tensor(1.))\n", "Batch y stats: (torch.Size([128]), torch.int64, tensor(4), tensor(65))\n", "\n" ] } ], "source": [ "# Instantiate the EMNIST data module imported from text_recognizer.data.emnist.\n", "data = EMNIST()\n", "# NOTE(review): presumably fetches/processes the raw data on first run -- confirm in the module.\n", "data.prepare_data()\n", "# NOTE(review): presumably builds the train/val/test splits used below -- confirm in the module.\n", "data.setup()\n", "# Printing the module emits its text summary (classes, mapping, dims, split sizes, batch stats).\n", "print(data)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([128, 1, 28, 28]) torch.float32 tensor(0.) tensor(0.2204) tensor(0.3593) tensor(1.)\n", "torch.Size([128]) torch.int64 tensor(4) tensor(4)\n" ] } ], "source": [ "# Pull the first batch from the test dataloader and summarise inputs and labels.\n", "test_loader = data.test_dataloader()\n", "x, y = next(iter(test_loader))\n", "\n", "# Shape, dtype and basic statistics, printed space-separated on one line each.\n", "x_summary = (x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())\n", "y_summary = (y.shape, y.dtype, y.min(), y.max())\n", "print(*x_summary)\n", "print(*y_summary)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "

" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# Show a 3x3 grid of randomly chosen test samples, each titled with its decoded label.\n", "SEED = 42  # fixed seed so Restart & Run All always draws the same samples\n", "rng = np.random.default_rng(SEED)\n", "\n", "fig, axes = plt.subplots(3, 3, figsize=(9, 9))\n", "for ax in axes.flat:\n", "    # Draw an index in [0, len(test set)); cast to a plain int for dataset indexing.\n", "    sample_idx = int(rng.integers(len(data.data_test)))\n", "    image, label = data.data_test[sample_idx]\n", "    # Each sample is reshaped to a 28x28 image for display.\n", "    ax.imshow(image.reshape(28, 28), cmap='gray')\n", "    ax.set_title(data.mapping[label])\n", "    ax.axis('off')  # pixel-coordinate ticks add nothing for image thumbnails\n", "fig.tight_layout()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.6" } }, "nbformat": 4, "nbformat_minor": 4 }