summary | refs | log | tree | commit | diff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r-- src/notebooks/02b-emnist-lines-dataset.ipynb 14
-rw-r--r-- src/notebooks/Untitled.ipynb 289
-rw-r--r-- src/text_recognizer/datasets/transforms.py 8
-rw-r--r-- src/text_recognizer/networks/crnn.py 12
4 files changed, 82 insertions, 241 deletions
diff --git a/src/notebooks/02b-emnist-lines-dataset.ipynb b/src/notebooks/02b-emnist-lines-dataset.ipynb
index 0f2626f..a9b13b4 100644
--- a/src/notebooks/02b-emnist-lines-dataset.ipynb
+++ b/src/notebooks/02b-emnist-lines-dataset.ipynb
@@ -31,28 +31,28 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
- "emnist_lines = EmnistLinesDataset(train=False,\n",
- " max_length = 34,\n",
+ "emnist_lines = EmnistLinesDataset(train=True,\n",
+ " max_length = 97,\n",
" min_overlap = 0.0,\n",
" max_overlap = 0.33,\n",
- " num_samples = 5_000,)"
+ " num_samples = 50_000,)"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2020-11-12 08:12:02.064 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_generate_data:154 - Generating data...\n",
- "2020-11-12 08:12:05.917 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:147 - EmnistLinesDataset loading data from HDF5...\n"
+ "2020-11-15 19:49:33.374 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_generate_data:154 - Generating data...\n",
+ "2020-11-15 19:50:10.082 | DEBUG | text_recognizer.datasets.emnist_lines_dataset:_load_data:147 - EmnistLinesDataset loading data from HDF5...\n"
]
}
],
diff --git a/src/notebooks/Untitled.ipynb b/src/notebooks/Untitled.ipynb
index f114ed9..208f098 100644
--- a/src/notebooks/Untitled.ipynb
+++ b/src/notebooks/Untitled.ipynb
@@ -40,19 +40,28 @@
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.models import VisionTransformerModel, TransformerEncoderModel\n",
- "from text_recognizer.datasets import IamLinesDataset\n",
- "from text_recognizer.datasets.transforms import Compose, AddTokens"
+ "from text_recognizer.models import TransformerModel\n",
+ "from text_recognizer.datasets import IamLinesDataset"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "target_transform = Compose([torch.tensor, AddTokens(init_token=\"<sos>\", pad_token=\"_\", eos_token=\"<eos>\")])\n",
- "dataset = IamLinesDataset(train=False, init_token=\"<sos>\", pad_token=\"_\", eos_token=\"<eos>\", target_transform=target_transform)\n",
+ "dataset = IamLinesDataset(train=False,\n",
+ " init_token=\"<sos>\",\n",
+ " pad_token=\"_\",\n",
+ " eos_token=\"<eos>\",\n",
+ " transform=[{\"type\": \"ToTensor\", \"args\": {}}],\n",
+ " target_transform=[\n",
+ " {\n",
+ " \"type\": \"AddTokens\",\n",
+ " \"args\": {\"init_token\": \"<sos>\", \"pad_token\": \"_\", \"eos_token\": \"<eos>\"},\n",
+ " }\n",
+ " ],\n",
+ " )\n",
"dataset.load_or_generate_data()"
]
},
@@ -62,55 +71,53 @@
"metadata": {},
"outputs": [],
"source": [
- "config_path = \"../training/experiments/VisionTransformerModel_IamLinesDataset_CNNTransformer/1102_221553/config.yml\"\n",
+ "config_path = \"../training/experiments/TransformerModel_IamLinesDataset_CNNTransformer/1116_082932/config.yml\"\n",
"with open(config_path, \"r\") as f:\n",
" experiment_config = yaml.safe_load(f)"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": 8,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'CNNTransformer'"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "dataset_args = experiment_config.get(\"dataset\", {})\n",
- "datasets_module = importlib.import_module(\"text_recognizer.datasets\")\n",
- "dataset_ = getattr(datasets_module, dataset_args[\"type\"])\n",
- "\n",
- "network_module = importlib.import_module(\"text_recognizer.networks\")\n",
- "network_fn_ = getattr(network_module, experiment_config[\"network\"][\"type\"])"
+ "experiment_config[\"network\"][\"type\"]"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
- "2020-11-03 07:32:07.256 | DEBUG | text_recognizer.models.base:load_weights:457 - Loading network with pretrained weights.\n"
+ "2020-11-16 20:07:51.973 | DEBUG | text_recognizer.models.base:load_weights:432 - Loading network with pretrained weights.\n"
]
}
],
"source": [
- "model = VisionTransformerModel(network_fn=network_fn_, dataset=dataset_, dataset_args=dataset_args)"
+ "model = TransformerModel(network_fn=experiment_config[\"network\"][\"type\"], dataset=experiment_config[\"dataset\"][\"type\"], dataset_args=experiment_config[\"dataset\"])"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2020-11-03 07:32:10.285 | DEBUG | text_recognizer.models.base:load_from_checkpoint:404 - Loading checkpoint...\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"checkpoint_path = \"../training/experiments/VisionTransformerModel_IamLinesDataset_CNNTransformer/1102_221553/model/last.pt\"\n",
"model.load_from_checkpoint(checkpoint_path)"
@@ -118,7 +125,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
@@ -127,101 +134,71 @@
},
{
"cell_type": "code",
- "execution_count": 24,
+ "execution_count": 95,
"metadata": {},
"outputs": [],
"source": [
- "data, target = dataset[0]\n",
+ "data, target = dataset[1006]\n",
"sentence = convert_y_label_to_string(target, dataset) "
]
},
{
"cell_type": "code",
- "execution_count": 25,
+ "execution_count": 102,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([98])"
- ]
- },
- "execution_count": 25,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "target.shape"
+ "data1 = (data - data.mean()) / data.std()"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": 103,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "([], [])"
+ "torch.Size([98])"
]
},
- "execution_count": 26,
+ "execution_count": 103,
"metadata": {},
"output_type": "execute_result"
- },
- {
- "data": {
- "image/png": "\n",
- "text/plain": [
- "<Figure size 1440x1440 with 1 Axes>"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
}
],
"source": [
- "plt.figure(figsize=(20, 20))\n",
- "plt.title(sentence)\n",
- "plt.imshow(data.squeeze(0).numpy(), cmap='gray')\n",
- "plt.xticks([])\n",
- "plt.yticks([])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 32,
- "metadata": {},
- "outputs": [],
- "source": [
- " def make_len_mask(inp):\n",
- " return (inp == 79).transpose(0, 1)"
+ "target.shape"
]
},
{
"cell_type": "code",
- "execution_count": 34,
+ "execution_count": 110,
"metadata": {},
"outputs": [
{
"data": {
+ "image/png": "\n",
"text/plain": [
- "torch.Size([98, 1])"
+ "<Figure size 2880x1440 with 1 Axes>"
]
},
- "execution_count": 34,
"metadata": {},
- "output_type": "execute_result"
+ "output_type": "display_data"
}
],
"source": [
- "make_len_mask(target.unsqueeze(0)).shape"
+ "plt.figure(figsize=(40, 20))\n",
+ "plt.title(sentence)\n",
+ "plt.imshow(data1.squeeze(0).numpy(), cmap='gray')\n",
+ "plt.xticks([])\n",
+ "plt.yticks([])\n",
+ "plt.show()"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": 111,
"metadata": {
"scrolled": true
},
@@ -229,163 +206,17 @@
{
"data": {
"text/plain": [
- "('to stel mire of a thar chishirchit<eos>', 0.20226626098155975)"
+ "('Horbwargethers sis tHater alate Bate Bath Con Hats the Bateries.<eos>',\n",
+ " 0.2612667977809906)"
]
},
- "execution_count": 27,
+ "execution_count": 111,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "model.predict_on_image(data)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 103,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2020-10-31 16:35:40.255 | DEBUG | text_recognizer.models.base:load_weights:457 - Loading network with pretrained weights.\n",
- "2020-10-31 16:35:40.837 | DEBUG | text_recognizer.models.base:load_from_checkpoint:404 - Loading checkpoint...\n"
- ]
- }
- ],
- "source": [
- "target_transform = Compose([torch.tensor, AddTokens(pad_token=\"_\", eos_token=\"<eos>\")])\n",
- "dataset = IamLinesDataset(train=False, pad_token=\"_\", eos_token=\"<eos>\", target_transform=target_transform)\n",
- "dataset.load_or_generate_data()\n",
- "\n",
- "\n",
- "config_path = \"../training/experiments/TransformerEncoderModel_IamLinesDataset_CNNTransformerEncoder/1031_150630/config.yml\"\n",
- "with open(config_path, \"r\") as f:\n",
- " experiment_config = yaml.safe_load(f)\n",
- "\n",
- "\n",
- "dataset_args = experiment_config.get(\"dataset\", {})\n",
- "datasets_module = importlib.import_module(\"text_recognizer.datasets\")\n",
- "dataset_ = getattr(datasets_module, dataset_args[\"type\"])\n",
- "\n",
- "network_module = importlib.import_module(\"text_recognizer.networks\")\n",
- "network_fn_ = getattr(network_module, experiment_config[\"network\"][\"type\"])\n",
- "\n",
- "\n",
- "checkpoint_path = \"../training/experiments/TransformerEncoderModel_IamLinesDataset_CNNTransformerEncoder/1031_150630/model/last.pt\"\n",
- "\n",
- "\n",
- "model = TransformerEncoderModel(network_fn=network_fn_, dataset=dataset_, dataset_args=dataset_args)\n",
- "model.load_from_checkpoint(checkpoint_path)\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 92,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "===============================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "===============================================================================================\n",
- "├─WideResidualNetwork: 1-1 [-1, 256, 2, 60] --\n",
- "| └─Sequential: 2-1 [-1, 256, 2, 60] --\n",
- "| | └─Conv2d: 3-1 [-1, 8, 28, 952] 72\n",
- "| | └─Sequential: 3-2 [-1, 16, 28, 952] --\n",
- "| | | └─WideBlock: 4-1 [-1, 16, 28, 952] 3,632\n",
- "| | └─Sequential: 3-3 [-1, 32, 14, 476] --\n",
- "| | | └─WideBlock: 4-2 [-1, 32, 14, 476] 14,432\n",
- "| | └─Sequential: 3-4 [-1, 64, 7, 238] --\n",
- "| | | └─WideBlock: 4-3 [-1, 64, 7, 238] 57,536\n",
- "| | └─Sequential: 3-5 [-1, 128, 4, 119] --\n",
- "| | | └─WideBlock: 4-4 [-1, 128, 4, 119] 229,760\n",
- "| | └─Sequential: 3-6 [-1, 256, 2, 60] --\n",
- "| | | └─WideBlock: 4-5 [-1, 256, 2, 60] 918,272\n",
- "├─Conv2d: 1-2 [-1, 97, 2, 60] 24,929\n",
- "├─Linear: 1-3 [-1, 97, 96] 11,616\n",
- "├─PositionalEncoding: 1-4 [-1, 97, 96] --\n",
- "| └─Dropout: 2-2 [-1, 97, 96] --\n",
- "├─TransformerEncoder: 1-5 [-1, 2, 96] --\n",
- "| └─ModuleList: 2 [] --\n",
- "| | └─TransformerEncoderLayer: 3-7 [-1, 2, 96] --\n",
- "| | | └─MultiheadAttention: 4-6 [-1, 2, 96] 37,248\n",
- "| | | └─Dropout: 4-7 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-8 [-1, 2, 96] 192\n",
- "| | | └─Linear: 4-9 [-1, 2, 2048] 198,656\n",
- "| | | └─Dropout: 4-10 [-1, 2, 2048] --\n",
- "| | | └─Linear: 4-11 [-1, 2, 96] 196,704\n",
- "| | | └─Dropout: 4-12 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-13 [-1, 2, 96] 192\n",
- "| | └─TransformerEncoderLayer: 3-8 [-1, 2, 96] --\n",
- "| | | └─MultiheadAttention: 4-14 [-1, 2, 96] 37,248\n",
- "| | | └─Dropout: 4-15 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-16 [-1, 2, 96] 192\n",
- "| | | └─Linear: 4-17 [-1, 2, 2048] 198,656\n",
- "| | | └─Dropout: 4-18 [-1, 2, 2048] --\n",
- "| | | └─Linear: 4-19 [-1, 2, 96] 196,704\n",
- "| | | └─Dropout: 4-20 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-21 [-1, 2, 96] 192\n",
- "| | └─TransformerEncoderLayer: 3-9 [-1, 2, 96] --\n",
- "| | | └─MultiheadAttention: 4-22 [-1, 2, 96] 37,248\n",
- "| | | └─Dropout: 4-23 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-24 [-1, 2, 96] 192\n",
- "| | | └─Linear: 4-25 [-1, 2, 2048] 198,656\n",
- "| | | └─Dropout: 4-26 [-1, 2, 2048] --\n",
- "| | | └─Linear: 4-27 [-1, 2, 96] 196,704\n",
- "| | | └─Dropout: 4-28 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-29 [-1, 2, 96] 192\n",
- "| | └─TransformerEncoderLayer: 3-10 [-1, 2, 96] --\n",
- "| | | └─MultiheadAttention: 4-30 [-1, 2, 96] 37,248\n",
- "| | | └─Dropout: 4-31 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-32 [-1, 2, 96] 192\n",
- "| | | └─Linear: 4-33 [-1, 2, 2048] 198,656\n",
- "| | | └─Dropout: 4-34 [-1, 2, 2048] --\n",
- "| | | └─Linear: 4-35 [-1, 2, 96] 196,704\n",
- "| | | └─Dropout: 4-36 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-37 [-1, 2, 96] 192\n",
- "| | └─TransformerEncoderLayer: 3-11 [-1, 2, 96] --\n",
- "| | | └─MultiheadAttention: 4-38 [-1, 2, 96] 37,248\n",
- "| | | └─Dropout: 4-39 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-40 [-1, 2, 96] 192\n",
- "| | | └─Linear: 4-41 [-1, 2, 2048] 198,656\n",
- "| | | └─Dropout: 4-42 [-1, 2, 2048] --\n",
- "| | | └─Linear: 4-43 [-1, 2, 96] 196,704\n",
- "| | | └─Dropout: 4-44 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-45 [-1, 2, 96] 192\n",
- "| | └─TransformerEncoderLayer: 3-12 [-1, 2, 96] --\n",
- "| | | └─MultiheadAttention: 4-46 [-1, 2, 96] 37,248\n",
- "| | | └─Dropout: 4-47 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-48 [-1, 2, 96] 192\n",
- "| | | └─Linear: 4-49 [-1, 2, 2048] 198,656\n",
- "| | | └─Dropout: 4-50 [-1, 2, 2048] --\n",
- "| | | └─Linear: 4-51 [-1, 2, 96] 196,704\n",
- "| | | └─Dropout: 4-52 [-1, 2, 96] --\n",
- "| | | └─LayerNorm: 4-53 [-1, 2, 96] 192\n",
- "| └─LayerNorm: 2-3 [-1, 2, 96] 192\n",
- "├─Linear: 1-6 [-1, 97, 81] 7,857\n",
- "===============================================================================================\n",
- "Total params: 3,866,250\n",
- "Trainable params: 3,866,250\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 18.78\n",
- "===============================================================================================\n",
- "Input size (MB): 0.10\n",
- "Forward/backward pass size (MB): 2.06\n",
- "Params size (MB): 14.75\n",
- "Estimated Total Size (MB): 16.91\n",
- "===============================================================================================\n"
- ]
- }
- ],
- "source": [
- "model.summary(experiment_config[\"train_args\"][\"input_shape\"], 4)"
+ "model.predict_on_image(data1)"
]
},
{
diff --git a/src/text_recognizer/datasets/transforms.py b/src/text_recognizer/datasets/transforms.py
index 1105f23..d1ca127 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/src/text_recognizer/datasets/transforms.py
@@ -64,3 +64,11 @@ class AddTokens:
target = torch.cat([sos, target], dim=0)
return target
+
+
+class Whitening:
+ """Whitening of Tensor, i.e. set mean to zero and std to one."""
+
+ def __call__(self, x: Tensor) -> Tensor:
+ """Apply the whitening."""
+ return (x - x.mean()) / x.std()
diff --git a/src/text_recognizer/networks/crnn.py b/src/text_recognizer/networks/crnn.py
index 9747429..778e232 100644
--- a/src/text_recognizer/networks/crnn.py
+++ b/src/text_recognizer/networks/crnn.py
@@ -1,4 +1,4 @@
-"""LSTM with CTC for handwritten text recognition within a line."""
+"""CRNN for handwritten text recognition."""
from typing import Dict, Tuple
from einops import rearrange, reduce
@@ -89,20 +89,22 @@ class ConvolutionalRecurrentNetwork(nn.Module):
x = self.backbone(x)
- # Avgerage pooling.
+ # Average pooling.
if self.avg_pool:
x = reduce(x, "(b t) c h w -> t b c", "mean", b=b, t=t)
else:
x = rearrange(x, "(b t) h -> t b h", b=b, t=t)
else:
# Encode the entire image with a CNN, and use the channels as temporal dimension.
- b = x.shape[0]
x = self.backbone(x)
- x = rearrange(x, "b c h w -> c b (h w)", b=b)
+ x = rearrange(x, "b c h w -> b w c h")
+ if self.adaptive_pool is not None:
+ x = self.adaptive_pool(x)
+ x = x.squeeze(3)
# Sequence predictions.
x, _ = self.rnn(x)
- # Sequence to classifcation layer.
+ # Sequence to classification layer.
x = self.decoder(x)
return x