1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline\n",
"\n",
"from importlib.util import find_spec\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"\n",
"# Fallback for when `text_recognizer` is not importable from the current\n",
"# environment: add the parent directory to the import search path.\n",
"if find_spec(\"text_recognizer\") is None:\n",
"    import sys\n",
"\n",
"    sys.path.append('..')\n",
"\n",
"from text_recognizer.data.emnist import EMNIST"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"EMNIST Dataset\n",
"Num classes: 83\n",
"Mapping: ['<b>', '<s>', '</s>', '<p>', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '!', '\"', '#', '&', \"'\", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '?']\n",
"Dims: (1, 28, 28)\n",
"Train/val/test sizes: 260276, 65070, 54028\n",
"Batch x stats: (torch.Size([128, 1, 28, 28]), torch.float32, tensor(0.), tensor(0.1673), tensor(0.3277), tensor(1.))\n",
"Batch y stats: (torch.Size([128]), torch.int64, tensor(4), tensor(65))\n",
"\n"
]
}
],
"source": [
"# Instantiate the EMNIST data module and run its prepare/setup hooks.\n",
"data = EMNIST()\n",
"data.prepare_data()\n",
"data.setup()\n",
"\n",
"# The printed form summarises class count, label mapping, input dims,\n",
"# train/val/test sizes and the stats of one batch.\n",
"print(data)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([128, 1, 28, 28]) torch.float32 tensor(0.) tensor(0.2204) tensor(0.3593) tensor(1.)\n",
"torch.Size([128]) torch.int64 tensor(4) tensor(4)\n"
]
}
],
"source": [
"# Take the first batch from the test dataloader and summarise it.\n",
"test_batches = iter(data.test_dataloader())\n",
"x, y = next(test_batches)\n",
"\n",
"# Inputs: shape, dtype, min, mean, std, max. Targets: shape, dtype, min, max.\n",
"x_stats = (x.shape, x.dtype, x.min(), x.mean(), x.std(), x.max())\n",
"y_stats = (y.shape, y.dtype, y.min(), y.max())\n",
"print(*x_stats)\n",
"print(*y_stats)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 648x648 with 9 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Show a grid of randomly chosen test images, each titled with its mapped label.\n",
"GRID_SIZE = 3\n",
"SEED = 42\n",
"\n",
"# A seeded generator keeps the figure stable across Restart & Run All;\n",
"# drawing without replacement avoids showing the same sample twice.\n",
"rng = np.random.default_rng(SEED)\n",
"num_samples = min(GRID_SIZE * GRID_SIZE, len(data.data_test))\n",
"sample_indices = rng.choice(len(data.data_test), size=num_samples, replace=False)\n",
"\n",
"fig = plt.figure(figsize=(9, 9))\n",
"for position, sample_index in enumerate(sample_indices, start=1):\n",
"    ax = fig.add_subplot(GRID_SIZE, GRID_SIZE, position)\n",
"    # Cast to a plain int: the dataset's indexing may not accept NumPy integer scalars.\n",
"    image, label = data.data_test[int(sample_index)]\n",
"    ax.imshow(image.reshape(28, 28), cmap='gray')\n",
"    ax.set_title(data.mapping[label])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
|