summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2020-09-09 23:31:31 +0200
committeraktersnurra <gustaf.rydholm@gmail.com>2020-09-09 23:31:31 +0200
commit2b63fd952bdc9c7c72edd501cbcdbf3231e98f00 (patch)
tree1c0e0898cb8b66faff9e5d410aa1f82d13542f68 /src
parente1b504bca41a9793ed7e88ef14f2e2cbd85724f2 (diff)
Created an abstract Dataset class for common methods.
Diffstat (limited to 'src')
-rw-r--r--src/notebooks/01-look-at-emnist.ipynb134
-rw-r--r--src/notebooks/01b-dataset_normalization.ipynb148
-rw-r--r--src/notebooks/02b-emnist-lines-dataset.ipynb124
-rw-r--r--src/notebooks/03a-line-prediction.ipynb31
-rw-r--r--src/notebooks/04a-look-at-iam-lines.ipynb101
-rw-r--r--src/notebooks/04b-look-at-iam-paragraphs.ipynb (renamed from src/notebooks/04-look-at-iam-paragraphs.ipynb)26
-rw-r--r--src/text_recognizer/datasets/__init__.py16
-rw-r--r--src/text_recognizer/datasets/dataset.py124
-rw-r--r--src/text_recognizer/datasets/emnist_dataset.py228
-rw-r--r--src/text_recognizer/datasets/emnist_lines_dataset.py56
-rw-r--r--src/text_recognizer/datasets/iam_dataset.py6
-rw-r--r--src/text_recognizer/datasets/iam_lines_dataset.py68
-rw-r--r--src/text_recognizer/datasets/iam_paragraphs_dataset.py70
-rw-r--r--src/text_recognizer/datasets/sentence_generator.py2
-rw-r--r--src/text_recognizer/datasets/util.py125
-rw-r--r--src/text_recognizer/models/base.py2
-rw-r--r--src/text_recognizer/networks/ctc.py2
17 files changed, 582 insertions, 681 deletions
diff --git a/src/notebooks/01-look-at-emnist.ipynb b/src/notebooks/01-look-at-emnist.ipynb
index 93083a5..564d14e 100644
--- a/src/notebooks/01-look-at-emnist.ipynb
+++ b/src/notebooks/01-look-at-emnist.ipynb
@@ -2,9 +2,18 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 18,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -22,7 +31,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
@@ -31,7 +40,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 35,
"metadata": {},
"outputs": [],
"source": [
@@ -40,7 +49,16 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset.load_or_generate_data()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
@@ -49,7 +67,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 38,
"metadata": {},
"outputs": [],
"source": [
@@ -58,7 +76,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 39,
"metadata": {},
"outputs": [
{
@@ -67,7 +85,7 @@
"55898"
]
},
- "execution_count": 10,
+ "execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
@@ -78,7 +96,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
@@ -87,7 +105,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 41,
"metadata": {},
"outputs": [
{
@@ -96,7 +114,7 @@
"3494"
]
},
- "execution_count": 19,
+ "execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
@@ -107,19 +125,74 @@
},
{
"cell_type": "code",
- "execution_count": 14,
+ "execution_count": 42,
"metadata": {},
"outputs": [
{
- "ename": "ValueError",
- "evalue": "only one element tensors can be converted to Python scalars",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-14-69c3b5027f10>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0md1\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mitem\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mValueError\u001b[0m: only one element tensors can be converted to Python scalars"
- ]
+ "data": {
+ "text/plain": [
+ "tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 4, 4, 4, 4, 4, 2, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 2, 4, 9, 32, 37, 37, 37, 32, 20, 1, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 3, 65, 109, 140, 204, 215, 217, 217, 201, 154, 22, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3,\n",
+ " 12, 122, 190, 222, 245, 249, 250, 250, 242, 206, 46, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 8, 79,\n",
+ " 127, 222, 247, 253, 235, 228, 249, 254, 254, 245, 114, 4, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 35, 91, 219,\n",
+ " 244, 252, 247, 207, 100, 84, 223, 251, 254, 250, 127, 4, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 9, 91, 163, 246,\n",
+ " 252, 244, 220, 127, 39, 48, 218, 250, 255, 250, 127, 4, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 5, 20, 95, 219, 246, 246,\n",
+ " 221, 127, 79, 10, 5, 37, 217, 250, 254, 249, 125, 4, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 20, 67, 175, 246, 252, 219,\n",
+ " 164, 47, 22, 1, 5, 39, 218, 250, 254, 245, 114, 4, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 1, 9, 95, 175, 250, 246, 219, 91,\n",
+ " 35, 1, 0, 0, 22, 84, 234, 252, 250, 220, 50, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 9, 35, 164, 221, 252, 219, 163, 35,\n",
+ " 9, 0, 0, 0, 46, 127, 246, 254, 245, 204, 34, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 7, 91, 163, 246, 252, 219, 91, 35, 1,\n",
+ " 0, 0, 0, 10, 128, 209, 254, 254, 220, 139, 9, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 3, 22, 159, 219, 252, 247, 164, 35, 9, 0,\n",
+ " 0, 0, 1, 36, 175, 233, 254, 254, 204, 115, 4, 0, 0, 0],\n",
+ " [ 0, 0, 0, 1, 36, 95, 232, 251, 232, 195, 47, 1, 0, 0,\n",
+ " 0, 9, 35, 163, 246, 253, 249, 232, 122, 45, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 7, 91, 164, 247, 251, 187, 127, 20, 0, 0, 0,\n",
+ " 1, 35, 91, 219, 253, 254, 234, 187, 67, 20, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 46, 207, 244, 247, 220, 80, 24, 1, 3, 8, 34,\n",
+ " 52, 164, 219, 253, 249, 234, 155, 79, 4, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 2, 81, 232, 251, 235, 179, 39, 12, 5, 22, 46, 115,\n",
+ " 139, 221, 246, 254, 234, 188, 79, 32, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 3, 112, 244, 254, 236, 193, 130, 127, 129, 173, 209, 245,\n",
+ " 250, 254, 253, 232, 154, 79, 4, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 46, 206, 242, 249, 248, 249, 250, 250, 250, 250, 250,\n",
+ " 250, 243, 219, 95, 22, 7, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 22, 154, 201, 217, 222, 245, 249, 249, 233, 222, 217,\n",
+ " 217, 202, 158, 36, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 1, 20, 32, 39, 51, 114, 125, 125, 82, 51, 37,\n",
+ " 37, 32, 20, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 2, 4, 5, 9, 32, 37, 37, 21, 9, 4,\n",
+ " 4, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
+ " [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
+ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],\n",
+ " dtype=torch.uint8)"
+ ]
+ },
+ "execution_count": 42,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
@@ -128,7 +201,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 43,
"metadata": {},
"outputs": [
{
@@ -137,7 +210,7 @@
"torch.Tensor"
]
},
- "execution_count": 4,
+ "execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
@@ -148,7 +221,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 44,
"metadata": {},
"outputs": [
{
@@ -169,7 +242,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
@@ -187,7 +260,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 46,
"metadata": {},
"outputs": [
{
@@ -207,7 +280,7 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 47,
"metadata": {},
"outputs": [
{
@@ -238,6 +311,13 @@
"metadata": {},
"outputs": [],
"source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/src/notebooks/01b-dataset_normalization.ipynb b/src/notebooks/01b-dataset_normalization.ipynb
deleted file mode 100644
index 9421816..0000000
--- a/src/notebooks/01b-dataset_normalization.ipynb
+++ /dev/null
@@ -1,148 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "\n",
- "%matplotlib inline\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "import torch\n",
- "from importlib.util import find_spec\n",
- "if find_spec(\"text_recognizer\") is None:\n",
- " import sys\n",
- " sys.path.append('..')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.datasets import EmnistDataLoader"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [],
- "source": [
- "data_loaders = EmnistDataLoader(splits=[\"train\"], sample_to_balance=True,\n",
- " subsample_fraction = None,\n",
- " transform = None,\n",
- " target_transform = None,\n",
- " batch_size = 512,\n",
- " shuffle = True,\n",
- " num_workers = 0,\n",
- " cuda = False,\n",
- " seed = 4711)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [],
- "source": [
- "loader = data_loaders(\"train\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "mean = 0.\n",
- "std = 0.\n",
- "nb_samples = 0.\n",
- "for data in loader:\n",
- " data, _ = data\n",
- " batch_samples = data.size(0)\n",
- " data = data.view(batch_samples, data.size(1), -1)\n",
- " mean += data.mean(2).sum(0)\n",
- " std += data.std(2).sum(0)\n",
- " nb_samples += batch_samples\n",
- "\n",
- "mean /= nb_samples\n",
- "std /= nb_samples"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([0.1731])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "mean"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([0.3247])"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "std"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/src/notebooks/02b-emnist-lines-dataset.ipynb b/src/notebooks/02b-emnist-lines-dataset.ipynb
index a7aabeb..2ef7da7 100644
--- a/src/notebooks/02b-emnist-lines-dataset.ipynb
+++ b/src/notebooks/02b-emnist-lines-dataset.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -31,61 +31,43 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
- "emnist_train = EmnistDataset(train=True, sample_to_balance=True)\n",
- "emnist_val = EmnistDataset(train=False, sample_to_balance=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2020-08-23 22:01:45.373 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:162 - EmnistLinesDataset loading data from HDF5...\n"
- ]
- }
- ],
- "source": [
"emnist_lines = EmnistLinesDataset(train=False)"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2020-08-23 22:01:46.598 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:162 - EmnistLinesDataset loading data from HDF5...\n"
+ "2020-09-09 23:07:57.716 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:134 - EmnistLinesDataset loading data from HDF5...\n"
]
}
],
"source": [
- "emnist_lines._load_or_generate_data()"
+ "emnist_lines.load_or_generate_data()"
]
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def convert_y_label_to_string(y, emnist_lines=emnist_lines):\n",
- " return ''.join([emnist_lines.mapping[i] for i in y])"
+ " return ''.join([emnist_lines.mapper(i) for i in y])"
]
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 16,
"metadata": {
"scrolled": false
},
@@ -230,7 +212,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": 17,
"metadata": {},
"outputs": [
{
@@ -255,94 +237,6 @@
},
{
"cell_type": "code",
- "execution_count": 14,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2020-08-05 00:40:26.070 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:153 - EmnistLinesDataset loading data from HDF5...\n"
- ]
- }
- ],
- "source": [
- "dl = EmnistLinesDataLoaders(\"train\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "metadata": {},
- "outputs": [],
- "source": [
- "ddl = dl(\"train\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "metadata": {},
- "outputs": [],
- "source": [
- "batch = next(iter(ddl))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 21,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([1, 28, 952])"
- ]
- },
- "execution_count": 21,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "batch[0][0].shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<matplotlib.image.AxesImage at 0x7f139b1cf1c0>"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 1440x1440 with 1 Axes>"
- ]
- },
- "metadata": {
- "needs_background": "light"
- },
- "output_type": "display_data"
- }
- ],
- "source": [
- "plt.figure(figsize=(20, 20))\n",
- "plt.imshow(batch[0][-1].squeeze(0), cmap='gray')"
- ]
- },
- {
- "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
diff --git a/src/notebooks/03a-line-prediction.ipynb b/src/notebooks/03a-line-prediction.ipynb
index 65c6dd6..336614f 100644
--- a/src/notebooks/03a-line-prediction.ipynb
+++ b/src/notebooks/03a-line-prediction.ipynb
@@ -49,7 +49,7 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2020-09-01 23:37:29.664 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:164 - EmnistLinesDataset loading data from HDF5...\n"
+ "2020-09-09 20:38:27.854 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:164 - EmnistLinesDataset loading data from HDF5...\n"
]
}
],
@@ -71,6 +71,35 @@
"cell_type": "code",
"execution_count": 6,
"metadata": {},
+ "outputs": [],
+ "source": [
+ "data, target = emnist_lines[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([34])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "target.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
"outputs": [
{
"name": "stderr",
diff --git a/src/notebooks/04a-look-at-iam-lines.ipynb b/src/notebooks/04a-look-at-iam-lines.ipynb
index aa62d19..0f9fefb 100644
--- a/src/notebooks/04a-look-at-iam-lines.ipynb
+++ b/src/notebooks/04a-look-at-iam-lines.ipynb
@@ -2,18 +2,9 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": 1,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -32,7 +23,7 @@
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -41,7 +32,7 @@
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -65,7 +56,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -74,7 +65,7 @@
"(97, 80)"
]
},
- "execution_count": 16,
+ "execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
@@ -85,7 +76,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -94,7 +85,7 @@
"'A MOVE to stop Mr. Gaitskell from________________________________________________________________'"
]
},
- "execution_count": 17,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -108,7 +99,7 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -260,6 +251,80 @@
},
{
"cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "data, target = dataset[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 28, 952])"
+ ]
+ },
+ "execution_count": 9,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([97])"
+ ]
+ },
+ "execution_count": 10,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "target.shape\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([10, 62, 22, 24, 31, 14, 62, 55, 50, 62, 54, 55, 50, 51, 62, 22, 53, 74,\n",
+ " 62, 16, 36, 44, 55, 54, 46, 40, 47, 47, 62, 41, 53, 50, 48, 79, 79, 79,\n",
+ " 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,\n",
+ " 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,\n",
+ " 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79, 79,\n",
+ " 79, 79, 79, 79, 79, 79, 79], dtype=torch.uint8)"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "target"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
diff --git a/src/notebooks/04-look-at-iam-paragraphs.ipynb b/src/notebooks/04b-look-at-iam-paragraphs.ipynb
index da420b0..a442420 100644
--- a/src/notebooks/04-look-at-iam-paragraphs.ipynb
+++ b/src/notebooks/04b-look-at-iam-paragraphs.ipynb
@@ -2,9 +2,18 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": 4,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "The autoreload extension is already loaded. To reload it, use:\n",
+ " %reload_ext autoreload\n"
+ ]
+ }
+ ],
"source": [
"\n",
"\n",
@@ -28,7 +37,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -48,19 +57,14 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2020-09-08 23:04:48.663 | INFO | text_recognizer.datasets.iam_paragraphs_dataset:_decide_on_crop_dims:190 - Max crop width and height were found to be 1240x1156.\n",
- "2020-09-08 23:04:48.664 | INFO | text_recognizer.datasets.iam_paragraphs_dataset:_decide_on_crop_dims:193 - Setting them to 1240x1240\n",
- "2020-09-08 23:04:48.665 | INFO | text_recognizer.datasets.iam_paragraphs_dataset:_process_iam_paragraphs:161 - Cropping paragraphs, generating ground truth, and saving debugging images to /home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/interim/iam_paragraphs/debug_crops\n",
- "2020-09-08 23:05:10.585 | ERROR | text_recognizer.datasets.iam_paragraphs_dataset:_crop_paragraph_image:240 - Rescued /home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/raw/iam/iamdb/forms/e01-086.jpg: could not broadcast input array from shape (687,1236) into shape (687,1240)\n",
- "2020-09-08 23:05:14.430 | ERROR | text_recognizer.datasets.iam_paragraphs_dataset:_crop_paragraph_image:240 - Rescued /home/akternurra/Documents/projects/quest-for-general-artifical-intelligence/projects/text-recognizer/data/raw/iam/iamdb/forms/e01-081.jpg: could not broadcast input array from shape (587,1236) into shape (587,1240)\n",
- "2020-09-08 23:05:29.910 | INFO | text_recognizer.datasets.iam_paragraphs_dataset:_load_iam_paragraphs:278 - Loading IAM paragraph crops and ground truth from image files...\n"
+ "2020-09-09 23:24:01.352 | INFO | text_recognizer.datasets.iam_paragraphs_dataset:_load_iam_paragraphs:244 - Loading IAM paragraph crops and ground truth from image files...\n"
]
},
{
@@ -83,7 +87,7 @@
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
diff --git a/src/text_recognizer/datasets/__init__.py b/src/text_recognizer/datasets/__init__.py
index ede4541..a3af9b1 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/src/text_recognizer/datasets/__init__.py
@@ -1,10 +1,5 @@
"""Dataset modules."""
-from .emnist_dataset import (
- DATA_DIRNAME,
- EmnistDataset,
- EmnistMapper,
- ESSENTIALS_FILENAME,
-)
+from .emnist_dataset import EmnistDataset, Transpose
from .emnist_lines_dataset import (
construct_image_from_string,
EmnistLinesDataset,
@@ -13,7 +8,14 @@ from .emnist_lines_dataset import (
from .iam_dataset import IamDataset
from .iam_lines_dataset import IamLinesDataset
from .iam_paragraphs_dataset import IamParagraphsDataset
-from .util import _download_raw_dataset, compute_sha256, download_url, Transpose
+from .util import (
+ _download_raw_dataset,
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+ ESSENTIALS_FILENAME,
+)
__all__ = [
"_download_raw_dataset",
diff --git a/src/text_recognizer/datasets/dataset.py b/src/text_recognizer/datasets/dataset.py
new file mode 100644
index 0000000..f328a0f
--- /dev/null
+++ b/src/text_recognizer/datasets/dataset.py
@@ -0,0 +1,124 @@
+"""Abstract dataset class."""
+from typing import Callable, Dict, Optional, Tuple, Union
+
+import torch
+from torch import Tensor
+from torch.utils import data
+from torchvision.transforms import ToTensor
+
+from text_recognizer.datasets.util import EmnistMapper
+
+
+class Dataset(data.Dataset):
+ """Abstract class with common methods for all datasets."""
+
+ def __init__(
+ self,
+ train: bool,
+ subsample_fraction: float = None,
+ transform: Optional[Callable] = None,
+ target_transform: Optional[Callable] = None,
+ ) -> None:
+ """Initialization of Dataset class.
+
+ Args:
+ train (bool): If True, loads the training set, otherwise the validation set is loaded.
+ subsample_fraction (float): Fraction of the samples to keep; must lie in (0, 1) when given. Defaults to None.
+ transform (Optional[Callable]): Transform(s) for input data. Defaults to None.
+ target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
+
+ Raises:
+ ValueError: If subsample_fraction is not None and outside the range (0, 1).
+
+ """
+ self.train = train
+ self.split = "train" if self.train else "test"
+
+ if subsample_fraction is not None:
+ if not 0.0 < subsample_fraction < 1.0:
+ raise ValueError("The subsample fraction must be in (0, 1).")
+ self.subsample_fraction = subsample_fraction
+
+ self._mapper = EmnistMapper()
+ self._input_shape = self._mapper.input_shape
+ self._output_shape = self._mapper._num_classes
+ self.num_classes = self.mapper.num_classes
+
+ # Set transforms.
+ self.transform = transform
+ if self.transform is None:
+ self.transform = ToTensor()
+
+ self.target_transform = target_transform
+ if self.target_transform is None:
+ self.target_transform = torch.tensor
+
+ self._data = None
+ self._targets = None
+
+ @property
+ def data(self) -> Tensor:
+ """The input data."""
+ return self._data
+
+ @property
+ def targets(self) -> Tensor:
+ """The target data."""
+ return self._targets
+
+ @property
+ def input_shape(self) -> Tuple:
+ """Input shape of the data."""
+ return self._input_shape
+
+ @property
+ def output_shape(self) -> Tuple:
+ """Output shape of the data."""
+ return self._output_shape
+
+ @property
+ def mapper(self) -> EmnistMapper:
+ """Returns the EmnistMapper."""
+ return self._mapper
+
+ @property
+ def mapping(self) -> Dict:
+ """Return EMNIST mapping from index to character."""
+ return self._mapper.mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the inverse mapping from character to index."""
+ return self.mapper.inverse_mapping
+
+ def _subsample(self) -> None:
+ """Keeps only the leading `subsample_fraction` of the data and targets."""
+ if self.subsample_fraction is None:
+ return
+ num_subsample = int(self.data.shape[0] * self.subsample_fraction)
+ self.data = self.data[:num_subsample]
+ self.targets = self.targets[:num_subsample]
+
+ def __len__(self) -> int:
+ """Returns the length of the dataset."""
+ return len(self.data)
+
+ def load_or_generate_data(self) -> None:
+ """Load or generate dataset data."""
+ raise NotImplementedError
+
+ def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
+ """Fetches samples from the dataset.
+
+ Args:
+ index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+
+ Raises:
+ NotImplementedError: If the method is not implemented in child class.
+
+ """
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ """Returns information about the dataset."""
+ raise NotImplementedError
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/src/text_recognizer/datasets/emnist_dataset.py
index 0715aae..81268fb 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/src/text_recognizer/datasets/emnist_dataset.py
@@ -2,139 +2,26 @@
import json
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Tuple, Type, Union
+from typing import Callable, Optional, Tuple, Union
from loguru import logger
import numpy as np
from PIL import Image
import torch
from torch import Tensor
-from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import EMNIST
-from torchvision.transforms import Compose, Normalize, ToTensor
+from torchvision.transforms import Compose, ToTensor
-from text_recognizer.datasets.util import Transpose
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import DATA_DIRNAME
-DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
-ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
+class Transpose:
+ """Transposes the EMNIST image to the correct orientation."""
-def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
- """Extract and saves EMNIST essentials."""
- labels = emnsit_dataset.classes
- labels.sort()
- mapping = [(i, str(label)) for i, label in enumerate(labels)]
- essentials = {
- "mapping": mapping,
- "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
- }
- logger.info("Saving emnist essentials...")
- with open(ESSENTIALS_FILENAME, "w") as f:
- json.dump(essentials, f)
-
-
-def download_emnist() -> None:
- """Download the EMNIST dataset via the PyTorch class."""
- logger.info(f"Data directory is: {DATA_DIRNAME}")
- dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
- save_emnist_essentials(dataset)
-
-
-class EmnistMapper:
- """Mapper between network output to Emnist character."""
-
- def __init__(self) -> None:
- """Loads the emnist essentials file with the mapping and input shape."""
- self.essentials = self._load_emnist_essentials()
- # Load dataset infromation.
- self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
- self._inverse_mapping = {v: k for k, v in self.mapping.items()}
- self._num_classes = len(self.mapping)
- self._input_shape = self.essentials["input_shape"]
-
- def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
- """Maps the token to emnist character or character index.
-
- If the token is an integer (index), the method will return the Emnist character corresponding to that index.
- If the token is a str (Emnist character), the method will return the corresponding index for that character.
-
- Args:
- token (Union[str, int, np.uint8]): Eihter a string or index (integer).
-
- Returns:
- Union[str, int]: The mapping result.
-
- Raises:
- KeyError: If the index or string does not exist in the mapping.
-
- """
- if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
- token
- ) in self.mapping:
- return self.mapping[int(token)]
- elif isinstance(token, str) and token in self._inverse_mapping:
- return self._inverse_mapping[token]
- else:
- raise KeyError(f"Token {token} does not exist in the mappings.")
-
- @property
- def mapping(self) -> Dict:
- """Returns the mapping between index and character."""
- return self._mapping
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the mapping between character and index."""
- return self._inverse_mapping
-
- @property
- def num_classes(self) -> int:
- """Returns the number of classes in the dataset."""
- return self._num_classes
-
- @property
- def input_shape(self) -> List[int]:
- """Returns the input shape of the Emnist characters."""
- return self._input_shape
-
- def _load_emnist_essentials(self) -> Dict:
- """Load the EMNIST mapping."""
- with open(str(ESSENTIALS_FILENAME)) as f:
- essentials = json.load(f)
- return essentials
-
- def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
- """Augment the mapping with extra symbols."""
- # Extra symbols in IAM dataset
- extra_symbols = [
- " ",
- "!",
- '"',
- "#",
- "&",
- "'",
- "(",
- ")",
- "*",
- "+",
- ",",
- "-",
- ".",
- "/",
- ":",
- ";",
- "?",
- ]
-
- # padding symbol
- extra_symbols.append("_")
-
- max_key = max(mapping.keys())
- extra_mapping = {}
- for i, symbol in enumerate(extra_symbols):
- extra_mapping[max_key + 1 + i] = symbol
-
- return {**mapping, **extra_mapping}
+ def __call__(self, image: Image) -> np.ndarray:
+ """Swaps the first two axes of the image."""
+ return np.array(image).swapaxes(0, 1)
class EmnistDataset(Dataset):
@@ -159,70 +46,33 @@ class EmnistDataset(Dataset):
target_transform (Optional[Callable]): Transform(s) for output data. Defaults to None.
seed (int): Seed number. Defaults to 4711.
- Raises:
- ValueError: If subsample_fraction is not None and outside the range (0, 1).
-
"""
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
- self.train = train
self.sample_to_balance = sample_to_balance
- if subsample_fraction is not None:
- if not 0.0 < subsample_fraction < 1.0:
- raise ValueError("The subsample fraction must be in (0, 1).")
- self.subsample_fraction = subsample_fraction
-
- self.transform = transform
- if self.transform is None:
+ # Have to transpose the emnist characters, ToTensor norms input between [0,1].
+ if transform is None:
self.transform = Compose([Transpose(), ToTensor()])
+ # The EMNIST dataset is already cast to tensors.
self.target_transform = target_transform
- self.seed = seed
-
- self._mapper = EmnistMapper()
- self._input_shape = self._mapper.input_shape
- self.num_classes = self._mapper.num_classes
-
- # Load dataset.
- self._data, self._targets = self.load_emnist_dataset()
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def inverse_mapping(self) -> Dict:
- """Returns the inverse mapping from character to index."""
- return self.mapper.inverse_mapping
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ self.seed = seed
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches samples from the dataset.
Args:
- index (Union[int, torch.Tensor]): The indices of the samples to fetch.
+ index (Union[int, Tensor]): The indices of the samples to fetch.
Returns:
- Tuple[torch.Tensor, torch.Tensor]: Data target tuple.
+ Tuple[Tensor, Tensor]: Data target tuple.
"""
if torch.is_tensor(index):
@@ -248,13 +98,11 @@ class EmnistDataset(Dataset):
f"Mapping: {self.mapper.mapping}\n"
)
- def _sample_to_balance(
- self, data: Tensor, targets: Tensor
- ) -> Tuple[np.ndarray, np.ndarray]:
+ def _sample_to_balance(self) -> None:
"""Because the dataset is not balanced, we take at most the mean number of instances per class."""
np.random.seed(self.seed)
- x = data
- y = targets
+ x = self._data
+ y = self._targets
num_to_sample = int(np.bincount(y.flatten()).mean())
all_sampled_indices = []
for label in np.unique(y.flatten()):
@@ -264,22 +112,10 @@ class EmnistDataset(Dataset):
indices = np.concatenate(all_sampled_indices)
x_sampled = x[indices]
y_sampled = y[indices]
- data = x_sampled
- targets = y_sampled
- return data, targets
-
- def _subsample(self, data: Tensor, targets: Tensor) -> Tuple[Tensor, Tensor]:
- """Subsamples the dataset to the specified fraction."""
- x = data
- y = targets
- num_samples = int(x.shape[0] * self.subsample_fraction)
- x_sampled = x[:num_samples]
- y_sampled = y[:num_samples]
- self.data = x_sampled
- self.targets = y_sampled
- return data, targets
+ self._data = x_sampled
+ self._targets = y_sampled
- def load_emnist_dataset(self) -> Tuple[Tensor, Tensor]:
+ def load_or_generate_data(self) -> None:
"""Fetch the EMNIST dataset."""
dataset = EMNIST(
root=DATA_DIRNAME,
@@ -290,13 +126,11 @@ class EmnistDataset(Dataset):
target_transform=None,
)
- data = dataset.data
- targets = dataset.targets
+ self._data = dataset.data
+ self._targets = dataset.targets
if self.sample_to_balance:
- data, targets = self._sample_to_balance(data, targets)
+ self._sample_to_balance()
if self.subsample_fraction is not None:
- data, targets = self._subsample(data, targets)
-
- return data, targets
+ self._subsample()
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/src/text_recognizer/datasets/emnist_lines_dataset.py
index 656131a..8fa77cd 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/src/text_recognizer/datasets/emnist_lines_dataset.py
@@ -9,17 +9,16 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
-from text_recognizer.datasets import (
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.emnist_dataset import EmnistDataset, Transpose
+from text_recognizer.datasets.sentence_generator import SentenceGenerator
+from text_recognizer.datasets.util import (
DATA_DIRNAME,
- EmnistDataset,
EmnistMapper,
ESSENTIALS_FILENAME,
)
-from text_recognizer.datasets.sentence_generator import SentenceGenerator
-from text_recognizer.datasets.util import Transpose
from text_recognizer.networks import sliding_window
DATA_DIRNAME = DATA_DIRNAME / "processed" / "emnist_lines"
@@ -52,18 +51,11 @@ class EmnistLinesDataset(Dataset):
seed (int): Seed number. Defaults to 4711.
"""
- self.train = train
-
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
+ super().__init__(
+ train=train, transform=transform, target_transform=target_transform,
+ )
# Extract dataset information.
- self._mapper = EmnistMapper()
self._input_shape = self._mapper.input_shape
self.num_classes = self._mapper.num_classes
@@ -75,24 +67,12 @@ class EmnistLinesDataset(Dataset):
self.input_shape[0],
self.input_shape[1] * self.max_length,
)
- self.output_shape = (self.max_length, self.num_classes)
+ self._output_shape = (self.max_length, self.num_classes)
self.seed = seed
# Placeholders for the dataset.
- self.data = None
- self.target = None
-
- # Load dataset.
- self._load_or_generate_data()
-
- @property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ self._data = None
+ self._target = None
def __getitem__(self, index: Union[int, Tensor]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
@@ -132,16 +112,6 @@ class EmnistLinesDataset(Dataset):
)
@property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
-
- @property
def data_filename(self) -> Path:
"""Path to the h5 file."""
filename = f"ml_{self.max_length}_o{self.min_overlap}_{self.max_overlap}_n{self.num_samples}.pt"
@@ -151,7 +121,7 @@ class EmnistLinesDataset(Dataset):
filename = "test_" + filename
return DATA_DIRNAME / filename
- def _load_or_generate_data(self) -> None:
+ def load_or_generate_data(self) -> None:
"""Loads the dataset, if it does not exist a new dataset is generated before loading it."""
np.random.seed(self.seed)
@@ -163,8 +133,8 @@ class EmnistLinesDataset(Dataset):
"""Loads the dataset from the h5 file."""
logger.debug("EmnistLinesDataset loading data from HDF5...")
with h5py.File(self.data_filename, "r") as f:
- self.data = f["data"][:]
- self.targets = f["targets"][:]
+ self._data = f["data"][:]
+ self._targets = f["targets"][:]
def _generate_data(self) -> str:
"""Generates a dataset with the Brown corpus and Emnist characters."""
diff --git a/src/text_recognizer/datasets/iam_dataset.py b/src/text_recognizer/datasets/iam_dataset.py
index 5e47350..f4a869d 100644
--- a/src/text_recognizer/datasets/iam_dataset.py
+++ b/src/text_recognizer/datasets/iam_dataset.py
@@ -7,10 +7,8 @@ from boltons.cacheutils import cachedproperty
import defusedxml.ElementTree as ET
from loguru import logger
import toml
-from torch.utils.data import Dataset
-from text_recognizer.datasets import DATA_DIRNAME
-from text_recognizer.datasets.util import _download_raw_dataset
+from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
@@ -20,7 +18,7 @@ DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
-class IamDataset(Dataset):
+class IamDataset:
"""IAM dataset.
"The IAM Lines dataset, first published at the ICDAR 1999, contains forms of unconstrained handwritten text,
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/src/text_recognizer/datasets/iam_lines_dataset.py
index 477f500..4a74b2b 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/src/text_recognizer/datasets/iam_lines_dataset.py
@@ -5,11 +5,15 @@ import h5py
from loguru import logger
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
-from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
-from text_recognizer.datasets.util import compute_sha256, download_url
+from text_recognizer.datasets.dataset import Dataset
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
PROCESSED_DATA_DIRNAME = DATA_DIRNAME / "processed" / "iam_lines"
@@ -29,47 +33,26 @@ class IamLinesDataset(Dataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
- self.train = train
- self.split = "train" if self.train else "test"
- self._mapper = EmnistMapper()
- self.num_classes = self.mapper.num_classes
-
- # Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
-
- self.subsample_fraction = subsample_fraction
- self.data = None
- self.targets = None
-
- @property
- def mapper(self) -> EmnistMapper:
- """Returns the EmnistMapper."""
- return self._mapper
-
- @property
- def mapping(self) -> Dict:
- """Return EMNIST mapping from index to character."""
- return self._mapper.mapping
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
@property
def input_shape(self) -> Tuple:
"""Input shape of the data."""
- return self.data.shape[1:]
+ return self.data.shape[1:] if self.data is not None else None
@property
def output_shape(self) -> Tuple:
"""Output shape of the data."""
- return self.targets.shape[1:] + (self.num_classes,)
-
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
+ return (
+ self.targets.shape[1:] + (self.num_classes,)
+ if self.targets is not None
+ else None
+ )
def load_or_generate_data(self) -> None:
"""Load or generate dataset data."""
@@ -78,19 +61,10 @@ class IamLinesDataset(Dataset):
logger.info("Downloading IAM lines...")
download_url(PROCESSED_DATA_URL, PROCESSED_DATA_FILENAME)
with h5py.File(PROCESSED_DATA_FILENAME, "r") as f:
- self.data = f[f"x_{self.split}"][:]
- self.targets = f[f"y_{self.split}"][:]
+ self._data = f[f"x_{self.split}"][:]
+ self._targets = f[f"y_{self.split}"][:]
self._subsample()
- def _subsample(self) -> None:
- """Only a fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
-
- num_samples = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_samples]
- self.targets = self.targets[:num_samples]
-
def __repr__(self) -> str:
"""Print info about the dataset."""
return (
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
index d65b346..4b34bd1 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/src/text_recognizer/datasets/iam_paragraphs_dataset.py
@@ -8,13 +8,17 @@ from loguru import logger
import numpy as np
import torch
from torch import Tensor
-from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from text_recognizer import util
-from text_recognizer.datasets.emnist_dataset import DATA_DIRNAME, EmnistMapper
+from text_recognizer.datasets.dataset import Dataset
from text_recognizer.datasets.iam_dataset import IamDataset
-from text_recognizer.datasets.util import compute_sha256, download_url
+from text_recognizer.datasets.util import (
+ compute_sha256,
+ DATA_DIRNAME,
+ download_url,
+ EmnistMapper,
+)
INTERIM_DATA_DIRNAME = DATA_DIRNAME / "interim" / "iam_paragraphs"
DEBUG_CROPS_DIRNAME = INTERIM_DATA_DIRNAME / "debug_crops"
@@ -28,11 +32,7 @@ SEED = 4711
class IamParagraphsDataset(Dataset):
- """IAM Paragraphs dataset for paragraphs of handwritten text.
-
- TODO: __getitem__, __len__, get_data_target_from_id
-
- """
+ """IAM Paragraphs dataset for paragraphs of handwritten text."""
def __init__(
self,
@@ -41,34 +41,20 @@ class IamParagraphsDataset(Dataset):
transform: Optional[Callable] = None,
target_transform: Optional[Callable] = None,
) -> None:
-
+ super().__init__(
+ train=train,
+ subsample_fraction=subsample_fraction,
+ transform=transform,
+ target_transform=target_transform,
+ )
# Load Iam dataset.
self.iam_dataset = IamDataset()
- self.train = train
- self.split = "train" if self.train else "test"
self.num_classes = 3
self._input_shape = (256, 256)
self._output_shape = self._input_shape + (self.num_classes,)
- self.subsample_fraction = subsample_fraction
-
- # Set transforms.
- self.transform = transform
- if self.transform is None:
- self.transform = ToTensor()
-
- self.target_transform = target_transform
- if self.target_transform is None:
- self.target_transform = torch.tensor
-
- self._data = None
- self._targets = None
self._ids = None
- def __len__(self) -> int:
- """Returns the length of the dataset."""
- return len(self.data)
-
def __getitem__(self, index: Union[Tensor, int]) -> Tuple[Tensor, Tensor]:
"""Fetches data, target pair of the dataset for a given and index or indices.
@@ -94,26 +80,6 @@ class IamParagraphsDataset(Dataset):
return data, targets
@property
- def input_shape(self) -> Tuple:
- """Input shape of the data."""
- return self._input_shape
-
- @property
- def output_shape(self) -> Tuple:
- """Output shape of the data."""
- return self._output_shape
-
- @property
- def data(self) -> Tensor:
- """The input data."""
- return self._data
-
- @property
- def targets(self) -> Tensor:
- """The target data."""
- return self._targets
-
- @property
def ids(self) -> Tensor:
"""Ids of the dataset."""
return self._ids
@@ -201,14 +167,6 @@ class IamParagraphsDataset(Dataset):
logger.info(f"Setting them to {max_crop_width}x{max_crop_width}")
return crop_dims
- def _subsample(self) -> None:
- """Only this fraction of the data will be loaded."""
- if self.subsample_fraction is None:
- return
- num_subsample = int(self.data.shape[0] * self.subsample_fraction)
- self.data = self.data[:num_subsample]
- self.targets = self.targets[:num_subsample]
-
def __repr__(self) -> str:
"""Return info about the dataset."""
return (
diff --git a/src/text_recognizer/datasets/sentence_generator.py b/src/text_recognizer/datasets/sentence_generator.py
index ee86bd4..dd76652 100644
--- a/src/text_recognizer/datasets/sentence_generator.py
+++ b/src/text_recognizer/datasets/sentence_generator.py
@@ -9,7 +9,7 @@ import nltk
from nltk.corpus.reader.util import ConcatenatedCorpusView
import numpy as np
-from text_recognizer.datasets import DATA_DIRNAME
+from text_recognizer.datasets.util import DATA_DIRNAME
NLTK_DATA_DIRNAME = DATA_DIRNAME / "raw" / "nltk"
diff --git a/src/text_recognizer/datasets/util.py b/src/text_recognizer/datasets/util.py
index dd16bed..3acf5db 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/src/text_recognizer/datasets/util.py
@@ -1,6 +1,7 @@
"""Util functions for datasets."""
import hashlib
import importlib
+import json
import os
from pathlib import Path
from typing import Callable, Dict, List, Optional, Type, Union
@@ -11,15 +12,129 @@ from loguru import logger
import numpy as np
from PIL import Image
from torch.utils.data import DataLoader, Dataset
+from torchvision.datasets import EMNIST
from tqdm import tqdm
+DATA_DIRNAME = Path(__file__).resolve().parents[3] / "data"
+ESSENTIALS_FILENAME = Path(__file__).resolve().parents[0] / "emnist_essentials.json"
-class Transpose:
- """Transposes the EMNIST image to the correct orientation."""
- def __call__(self, image: Image) -> np.ndarray:
- """Swaps axis."""
- return np.array(image).swapaxes(0, 1)
+def save_emnist_essentials(emnsit_dataset: type = EMNIST) -> None:
+ """Extracts and saves EMNIST essentials."""
+ labels = emnsit_dataset.classes
+ labels.sort()
+ mapping = [(i, str(label)) for i, label in enumerate(labels)]
+ essentials = {
+ "mapping": mapping,
+ "input_shape": tuple(emnsit_dataset[0][0].shape[:]),
+ }
+ logger.info("Saving emnist essentials...")
+ with open(ESSENTIALS_FILENAME, "w") as f:
+ json.dump(essentials, f)
+
+
+def download_emnist() -> None:
+ """Download the EMNIST dataset via the PyTorch class."""
+ logger.info(f"Data directory is: {DATA_DIRNAME}")
+ dataset = EMNIST(root=DATA_DIRNAME, split="byclass", download=True)
+ save_emnist_essentials(dataset)
+
+
+class EmnistMapper:
+ """Mapper from network output to Emnist character."""
+
+ def __init__(self) -> None:
+ """Loads the emnist essentials file with the mapping and input shape."""
+ self.essentials = self._load_emnist_essentials()
+ # Load dataset information.
+ self._mapping = self._augment_emnist_mapping(dict(self.essentials["mapping"]))
+ self._inverse_mapping = {v: k for k, v in self.mapping.items()}
+ self._num_classes = len(self.mapping)
+ self._input_shape = self.essentials["input_shape"]
+
+ def __call__(self, token: Union[str, int, np.uint8]) -> Union[str, int]:
+ """Maps the token to emnist character or character index.
+
+ If the token is an integer (index), the method will return the Emnist character corresponding to that index.
+ If the token is a str (Emnist character), the method will return the corresponding index for that character.
+
+ Args:
+ token (Union[str, int, np.uint8]): Either a string or index (integer).
+
+ Returns:
+ Union[str, int]: The mapping result.
+
+ Raises:
+ KeyError: If the index or string does not exist in the mapping.
+
+ """
+ if (isinstance(token, np.uint8) or isinstance(token, int)) and int(
+ token
+ ) in self.mapping:
+ return self.mapping[int(token)]
+ elif isinstance(token, str) and token in self._inverse_mapping:
+ return self._inverse_mapping[token]
+ else:
+ raise KeyError(f"Token {token} does not exist in the mappings.")
+
+ @property
+ def mapping(self) -> Dict:
+ """Returns the mapping between index and character."""
+ return self._mapping
+
+ @property
+ def inverse_mapping(self) -> Dict:
+ """Returns the mapping between character and index."""
+ return self._inverse_mapping
+
+ @property
+ def num_classes(self) -> int:
+ """Returns the number of classes in the dataset."""
+ return self._num_classes
+
+ @property
+ def input_shape(self) -> List[int]:
+ """Returns the input shape of the Emnist characters."""
+ return self._input_shape
+
+ def _load_emnist_essentials(self) -> Dict:
+ """Load the EMNIST mapping."""
+ with open(str(ESSENTIALS_FILENAME)) as f:
+ essentials = json.load(f)
+ return essentials
+
+ def _augment_emnist_mapping(self, mapping: Dict) -> Dict:
+ """Augment the mapping with extra symbols."""
+ # Extra symbols in IAM dataset
+ extra_symbols = [
+ " ",
+ "!",
+ '"',
+ "#",
+ "&",
+ "'",
+ "(",
+ ")",
+ "*",
+ "+",
+ ",",
+ "-",
+ ".",
+ "/",
+ ":",
+ ";",
+ "?",
+ ]
+
+ # padding symbol
+ extra_symbols.append("_")
+
+ max_key = max(mapping.keys())
+ extra_mapping = {}
+ for i, symbol in enumerate(extra_symbols):
+ extra_mapping[max_key + 1 + i] = symbol
+
+ return {**mapping, **extra_mapping}
def compute_sha256(filename: Union[Path, str]) -> str:
diff --git a/src/text_recognizer/models/base.py b/src/text_recognizer/models/base.py
index 153e19a..d23fe56 100644
--- a/src/text_recognizer/models/base.py
+++ b/src/text_recognizer/models/base.py
@@ -140,6 +140,7 @@ class Model(ABC):
if not self.data_prepared:
# Load train dataset.
train_dataset = self.dataset(train=True, **self.dataset_args["args"])
+ train_dataset.load_or_generate_data()
# Set input shape.
self._input_shape = train_dataset.input_shape
@@ -156,6 +157,7 @@ class Model(ABC):
# Load test dataset.
self.test_dataset = self.dataset(train=False, **self.dataset_args["args"])
+ self.test_dataset.load_or_generate_data()
# Set the flag to true to disable ability to load data agian.
self.data_prepared = True
diff --git a/src/text_recognizer/networks/ctc.py b/src/text_recognizer/networks/ctc.py
index fc0d21d..72f18b8 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/src/text_recognizer/networks/ctc.py
@@ -5,7 +5,7 @@ from einops import rearrange
import torch
from torch import Tensor
-from text_recognizer.datasets import EmnistMapper
+from text_recognizer.datasets.util import EmnistMapper
def greedy_decoder(