-rw-r--r--  notebooks/00-scratch-pad.ipynb                           |  10
-rw-r--r--  notebooks/01-look-at-emnist.ipynb                        |   4
-rw-r--r--  notebooks/02b-look-at-emnist-lines.ipynb                 |   8
-rw-r--r--  notebooks/03-look-at-iam-lines.ipynb                     |   4
-rw-r--r--  notebooks/03-look-at-iam-paragraphs.ipynb                |   8
-rw-r--r--  notebooks/04b-look-at-iam-paragraphs-predictions.ipynb   |   8
-rw-r--r--  notebooks/04b-look-at-iam-paragraphs.ipynb               |   8
-rw-r--r--  notebooks/05c-test-model-end-to-end.ipynb                | 340
-rw-r--r--  text_recognizer/criterions/label_smoothing.py (renamed from text_recognizer/criterions/label_smoothing_loss.py) | 0
-rw-r--r--  text_recognizer/data/base_dataset.py                     |   1
-rw-r--r--  text_recognizer/data/emnist.py                           |   2
-rw-r--r--  text_recognizer/data/iam_extended_paragraphs.py          |  23
-rw-r--r--  text_recognizer/data/iam_lines.py                        |   6
-rw-r--r--  text_recognizer/data/iam_paragraphs.py                   |   7
-rw-r--r--  text_recognizer/data/iam_synthetic_paragraphs.py         |  12
-rw-r--r--  text_recognizer/models/base.py                           |  31
-rw-r--r--  text_recognizer/models/transformer.py                    |  26
-rw-r--r--  text_recognizer/networks/base.py                         |  18
-rw-r--r--  text_recognizer/networks/conv_transformer.py (renamed from text_recognizer/networks/cnn_tranformer.py) | 27
-rw-r--r--  training/conf/criterion/label_smoothing.yaml             |   4
-rw-r--r--  training/conf/mapping/word_piece.yaml                    |   9
-rw-r--r--  training/conf/model/lit_transformer.yaml                 |   4
-rw-r--r--  training/conf/network/conv_transformer.yaml              |  13

23 files changed, 237 insertions(+), 336 deletions(-)
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 2c98064..0350727 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -49,9 +49,7 @@
{
"cell_type": "code",
"execution_count": 7,
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [],
"source": [
"en = EfficientNet(\"b0\")"
@@ -268,9 +266,7 @@
{
"cell_type": "code",
"execution_count": null,
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [],
"source": [
"summary(en, (1, 224, 224));"
@@ -1157,7 +1153,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
diff --git a/notebooks/01-look-at-emnist.ipynb b/notebooks/01-look-at-emnist.ipynb
index 5b5310e..1ca06c5 100644
--- a/notebooks/01-look-at-emnist.ipynb
+++ b/notebooks/01-look-at-emnist.ipynb
@@ -106,7 +106,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -120,7 +120,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.5"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/notebooks/02b-look-at-emnist-lines.ipynb b/notebooks/02b-look-at-emnist-lines.ipynb
index 93893f9..89045a4 100644
--- a/notebooks/02b-look-at-emnist-lines.ipynb
+++ b/notebooks/02b-look-at-emnist-lines.ipynb
@@ -136,9 +136,7 @@
{
"cell_type": "code",
"execution_count": 9,
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [
{
"data": {
@@ -270,7 +268,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -284,7 +282,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.1"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/notebooks/03-look-at-iam-lines.ipynb b/notebooks/03-look-at-iam-lines.ipynb
index ab12642..383381d 100644
--- a/notebooks/03-look-at-iam-lines.ipynb
+++ b/notebooks/03-look-at-iam-lines.ipynb
@@ -228,7 +228,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -242,7 +242,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.1"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index 315b7bf..dd3a934 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -317,9 +317,7 @@
"cell_type": "code",
"execution_count": 61,
"id": "e7778ae2",
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [
{
"data": {
@@ -507,7 +505,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -521,7 +519,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.5"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb
index 5662eb1..40d371c 100644
--- a/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb
+++ b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb
@@ -99,9 +99,7 @@
{
"cell_type": "code",
"execution_count": 39,
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [
{
"data": {
@@ -247,7 +245,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -261,7 +259,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.8.2"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/notebooks/04b-look-at-iam-paragraphs.ipynb b/notebooks/04b-look-at-iam-paragraphs.ipynb
index 11ebddf..414ea85 100644
--- a/notebooks/04b-look-at-iam-paragraphs.ipynb
+++ b/notebooks/04b-look-at-iam-paragraphs.ipynb
@@ -97,9 +97,7 @@
{
"cell_type": "code",
"execution_count": 48,
- "metadata": {
- "scrolled": false
- },
+ "metadata": {},
"outputs": [
{
"data": {
@@ -242,7 +240,7 @@
],
"metadata": {
"kernelspec": {
- "display_name": "Python 3",
+ "display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@@ -256,7 +254,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.9.1"
+ "version": "3.9.6"
}
},
"nbformat": 4,
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index a0b4ee9..e2ccb3c 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -19,43 +19,13 @@
"from importlib.util import find_spec\n",
"if find_spec(\"text_recognizer\") is None:\n",
" import sys\n",
- " sys.path.append('..')"
+ " sys.path.append('..')\n",
+ " "
]
},
{
"cell_type": "code",
"execution_count": 2,
- "id": "2ab9ac7a-a288-45bc-bfb7-8579a3a38d93",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torch.nn.functional as F"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "ecab65ba-5aa0-45f0-99d7-e837464185ac",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<function torch.nn.functional.softmax(input: torch.Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> torch.Tensor>"
- ]
- },
- "execution_count": 4,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "F.softmax"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
"id": "3e812a1e",
"metadata": {},
"outputs": [],
@@ -65,309 +35,231 @@
},
{
"cell_type": "code",
- "execution_count": 10,
- "id": "a42a7988",
+ "execution_count": 3,
+ "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0",
"metadata": {},
"outputs": [],
"source": [
- "@attr.s\n",
- "class C(object):\n",
- " d = {2: \"hej\"}\n",
- " x: F.softmax = attr.ib(init=False, default=F.softmax)\n",
- " @x.validator\n",
- " def check(self, attribute, value):\n",
- " print(attribute)\n",
- " print(self.x)"
+ "from hydra import compose, initialize\n",
+ "from omegaconf import OmegaConf\n",
+ "from hydra.utils import instantiate"
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "id": "660a7b1f",
+ "execution_count": 4,
+ "id": "9c797159-845e-42c6-bd65-1c976ad627cd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "Attribute(name='x', default=<function softmax at 0x7fb624839ca0>, validator=<function C.check at 0x7fb622ce2040>, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=False, metadata=mappingproxy({}), type=<function softmax at 0x7fb624839ca0>, converter=None, kw_only=False, inherited=False, on_setattr=None)\n",
- "<function softmax at 0x7fb624839ca0>\n"
+ "encoder:\n",
+ " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n",
+ " arch: b0\n",
+ " out_channels: 1280\n",
+ " stochastic_dropout_rate: 0.2\n",
+ " bn_momentum: 0.99\n",
+ " bn_eps: 0.001\n",
+ "decoder:\n",
+ " _target_: text_recognizer.networks.transformer.Decoder\n",
+ " dim: 256\n",
+ " depth: 2\n",
+ " num_heads: 8\n",
+ " attn_fn: text_recognizer.networks.transformer.attention.Attention\n",
+ " attn_kwargs:\n",
+ " num_heads: 8\n",
+ " dim_head: 64\n",
+ " dropout_rate: 0.2\n",
+ " norm_fn: torch.nn.LayerNorm\n",
+ " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n",
+ " ff_kwargs:\n",
+ " dim: 256\n",
+ " dim_out: null\n",
+ " expansion_factor: 4\n",
+ " glu: true\n",
+ " dropout_rate: 0.2\n",
+ " rotary_emb: null\n",
+ " rotary_emb_dim: null\n",
+ " cross_attend: true\n",
+ " pre_norm: true\n",
+ "_target_: text_recognizer.networks.conv_transformer.ConvTransformer\n",
+ "input_dims:\n",
+ "- 1\n",
+ "- 576\n",
+ "- 640\n",
+ "hidden_dim: 256\n",
+ "dropout_rate: 0.2\n",
+ "max_output_len: 682\n",
+ "num_classes: 1004\n",
+ "start_token: <s>\n",
+ "end_token: <e>\n",
+ "pad_token: <p>\n",
+ "\n",
+ "{'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 256, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'num_heads': 8, 'dim_head': 64, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim': 256, 'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'rotary_emb': None, 'rotary_emb_dim': None, 'cross_attend': True, 'pre_norm': True}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 576, 640], 'hidden_dim': 256, 'dropout_rate': 0.2, 'max_output_len': 682, 'num_classes': 1004, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n"
]
}
],
"source": [
- "c = C()"
+ "# context initialization\n",
+ "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"conv_transformer\")\n",
+ " print(OmegaConf.to_yaml(cfg))\n",
+ " print(cfg)"
]
},
{
"cell_type": "code",
- "execution_count": 12,
- "id": "9c3d1163",
+ "execution_count": 5,
+ "id": "cdb895b6-8949-4318-8a40-06fb5ed5e8d6",
"metadata": {},
"outputs": [
{
- "data": {
- "text/plain": [
- "<function torch.nn.functional.softmax(input: torch.Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> torch.Tensor>"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "_target_: text_recognizer.data.mappings.WordPieceMapping\n",
+ "num_features: 1000\n",
+ "tokens: iamdb_1kwp_tokens_1000.txt\n",
+ "lexicon: iamdb_1kwp_lex_1000.txt\n",
+ "data_dir: null\n",
+ "use_words: false\n",
+ "prepend_wordsep: false\n",
+ "special_tokens:\n",
+ "- <s>\n",
+ "- <e>\n",
+ "- <p>\n",
+ "extra_symbols:\n",
+ "- '\n",
+ "\n",
+ " '\n",
+ "\n",
+ "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}\n"
+ ]
}
],
"source": [
- "c.x"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 26,
- "id": "b3c8879c",
- "metadata": {},
- "outputs": [],
- "source": [
- "from torch import nn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 33,
- "id": "2f5f6b75",
- "metadata": {},
- "outputs": [],
- "source": [
- "l = nn.ModuleList([])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 34,
- "id": "9938ec53",
- "metadata": {},
- "outputs": [],
- "source": [
- "f = nn.Linear(10, 10)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 35,
- "id": "fc49db78",
- "metadata": {},
- "outputs": [],
- "source": [
- "for _ in range(10):\n",
- " l.append(f)"
+ "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"word_piece\")\n",
+ " print(OmegaConf.to_yaml(cfg))\n",
+ " print(cfg)"
]
},
{
"cell_type": "code",
- "execution_count": 36,
- "id": "e799a9dc",
+ "execution_count": 6,
+ "id": "b6181656-580a-4d96-8495-b6bb510944cc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "ModuleList(\n",
- " (0): Linear(in_features=10, out_features=10, bias=True)\n",
- " (1): Linear(in_features=10, out_features=10, bias=True)\n",
- " (2): Linear(in_features=10, out_features=10, bias=True)\n",
- " (3): Linear(in_features=10, out_features=10, bias=True)\n",
- " (4): Linear(in_features=10, out_features=10, bias=True)\n",
- " (5): Linear(in_features=10, out_features=10, bias=True)\n",
- " (6): Linear(in_features=10, out_features=10, bias=True)\n",
- " (7): Linear(in_features=10, out_features=10, bias=True)\n",
- " (8): Linear(in_features=10, out_features=10, bias=True)\n",
- " (9): Linear(in_features=10, out_features=10, bias=True)\n",
- ")"
+ "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}"
]
},
- "execution_count": 36,
+ "execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "\n",
- "l"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 37,
- "id": "17213dfb",
- "metadata": {},
- "outputs": [
- {
- "ename": "AttributeError",
- "evalue": "'Linear' object has no attribute 'copy'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m/tmp/ipykernel_31696/2302067867.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mff\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 1131\u001b[0m type(self).__name__, name))\n\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mAttributeError\u001b[0m: 'Linear' object has no attribute 'copy'"
- ]
- }
- ],
- "source": [
- "ff = f.copy()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 38,
- "id": "60277c26",
- "metadata": {},
- "outputs": [],
- "source": [
- "from copy import deepcopy"
+ "cfg"
]
},
{
"cell_type": "code",
- "execution_count": 39,
- "id": "cf86534a",
+ "execution_count": null,
+ "id": "5cd80d84-3ae5-4bb4-bc00-0dac7b22e134",
"metadata": {},
"outputs": [],
- "source": [
- "ff = deepcopy(f)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 43,
- "id": "2a260dc8",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "140011688939472"
- ]
- },
- "execution_count": 43,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "id(ff)"
- ]
+ "source": []
},
{
"cell_type": "code",
- "execution_count": 42,
- "id": "6dcf5f63",
+ "execution_count": 8,
+ "id": "0c123c76-ed90-49fa-903b-70ad60a33f16",
"metadata": {},
"outputs": [
{
- "data": {
- "text/plain": [
- "140011688936544"
- ]
- },
- "execution_count": 42,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-07-29 23:02:56.650 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
+ ]
}
],
"source": [
- "id(f)"
+ "mapping = instantiate(cfg)"
]
},
{
"cell_type": "code",
- "execution_count": 44,
- "id": "74958f8d",
+ "execution_count": 9,
+ "id": "ff6c57f0-3c96-418e-8192-cd12bf79c073",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "140011688936544"
+ "tensor([1002])"
]
},
- "execution_count": 44,
+ "execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "id(l[0])"
+ "mapping.get_index(\"<p>\")"
]
},
{
"cell_type": "code",
- "execution_count": 45,
- "id": "bcceabd5",
+ "execution_count": 10,
+ "id": "348391ec-0cf7-49f6-bac2-26bc8c966705",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "140011688936544"
+ "1006"
]
},
- "execution_count": 45,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "id(l[1])"
+ "len(mapping)"
]
},
{
"cell_type": "code",
- "execution_count": 58,
- "id": "191a0b03",
+ "execution_count": 15,
+ "id": "67673bf2-79c6-4010-93dd-9c9ba8f9a90e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "'nn'"
+ "tensor([1003])"
]
},
- "execution_count": 58,
+ "execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "\".\".join(\"nn.LayerNorm\".split(\".\")[:-1])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 60,
- "id": "4ff8ae08",
- "metadata": {},
- "outputs": [
- {
- "ename": "AttributeError",
- "evalue": "'str' object has no attribute 'LayerNorm'",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m/tmp/ipykernel_31696/162121485.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"torch.nn\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"LayerNorm\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;31mAttributeError\u001b[0m: 'str' object has no attribute 'LayerNorm'"
- ]
- }
- ],
- "source": [
- "getattr(\"torch.nn\", \"LayerNorm\")"
+ "mapping.get_index(\"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "4d536bf2",
+ "id": "8923ea1e-b571-42ee-bfd7-4984aa70644f",
"metadata": {},
"outputs": [],
"source": []
diff --git a/text_recognizer/criterions/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing.py
index 40a7609..40a7609 100644
--- a/text_recognizer/criterions/label_smoothing_loss.py
+++ b/text_recognizer/criterions/label_smoothing.py
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 4318dfb..c26f1c9 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -29,6 +29,7 @@ class BaseDataset(Dataset):
super().__init__()
def __attrs_post_init__(self) -> None:
+ # TODO: refactor this
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index d51a42a..2d0ac29 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -46,7 +46,7 @@ class EMNIST(BaseDataModule):
EMNIST ByClass: 814,255 characters. 62 unbalanced classes.
"""
- train_fraction: float = attr.ib()
+ train_fraction: float = attr.ib(default=0.8)
transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
def __attrs_post_init__(self) -> None:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 886e37e..58c7369 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -13,23 +13,24 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
@attr.s(auto_attribs=True)
class IAMExtendedParagraphs(BaseDataModule):
- train_fraction: float = attr.ib()
+ augment: bool = attr.ib(default=True)
+ train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
- self.batch_size,
- self.num_workers,
- self.train_fraction,
- self.augment,
- self.word_pieces,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ train_fraction=self.train_fraction,
+ augment=self.augment,
+ word_pieces=self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
- self.batch_size,
- self.num_workers,
- self.train_fraction,
- self.augment,
- self.word_pieces,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ train_fraction=self.train_fraction,
+ augment=self.augment,
+ word_pieces=self.word_pieces,
)
self.dims = self.iam_paragraphs.dims
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index e45e5c8..705cfa3 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -34,6 +34,7 @@ SEED = 4711
PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines"
IMAGE_HEIGHT = 56
IMAGE_WIDTH = 1024
+MAX_LABEL_LENGTH = 89
@attr.s(auto_attribs=True)
@@ -42,11 +43,12 @@ class IAMLines(BaseDataModule):
augment: bool = attr.ib(default=True)
fraction: float = attr.ib(default=0.8)
+ dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH))
+ output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
def __attrs_post_init__(self) -> None:
+ # TODO: refactor this
self.mapping, self.inverse_mapping, _ = emnist_mapping()
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (89, 1)
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index bdfb490..9977978 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -41,6 +41,8 @@ class IAMParagraphs(BaseDataModule):
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
word_pieces: bool = attr.ib(default=False)
+ dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH))
+ output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
def __attrs_post_init__(self) -> None:
self.mapping, self.inverse_mapping, _ = emnist_mapping(
@@ -49,11 +51,6 @@ class IAMParagraphs(BaseDataModule):
if self.word_pieces:
self.mapping = WordPieceMapping()
- self.train_fraction = train_fraction
-
- self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH)
- self.output_dims = (MAX_LABEL_LENGTH, 1)
-
def prepare_data(self) -> None:
"""Create data for training/testing."""
if PROCESSED_DATA_DIRNAME.exists():
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index 00fa2b6..a3697e7 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -2,6 +2,7 @@
import random
from typing import Any, List, Sequence, Tuple
+import attr
from loguru import logger
import numpy as np
from PIL import Image
@@ -33,19 +34,10 @@ PROCESSED_DATA_DIRNAME = (
)
+@attr.s(auto_attribs=True)
class IAMSyntheticParagraphs(IAMParagraphs):
"""IAM Handwriting database of synthetic paragraphs."""
- def __init__(
- self,
- batch_size: int = 16,
- num_workers: int = 0,
- train_fraction: float = 0.8,
- augment: bool = True,
- word_pieces: bool = False,
- ) -> None:
- super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces)
-
def prepare_data(self) -> None:
"""Prepare IAM lines to be used to generate paragraphs."""
if PROCESSED_DATA_DIRNAME.exists():
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index f95df0f..3b83056 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type
import attr
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig
-import pytorch_lightning as LightningModule
+from pytorch_lightning import LightningModule
import torch
from torch import nn
from torch import Tensor
import torchmetrics
+from text_recognizer.networks.base import BaseNetwork
+
@attr.s
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
- network: Type[nn.Module] = attr.ib()
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ network: Type[BaseNetwork] = attr.ib()
criterion_config: DictConfig = attr.ib(converter=DictConfig)
optimizer_config: DictConfig = attr.ib(converter=DictConfig)
lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
@@ -24,23 +29,13 @@ class BaseLitModel(LightningModule):
interval: str = attr.ib()
monitor: str = attr.ib(default="val/loss")
- loss_fn = attr.ib(init=False)
-
- train_acc = attr.ib(init=False)
- val_acc = attr.ib(init=False)
- test_acc = attr.ib(init=False)
-
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
- def __attrs_post_init__(self) -> None:
- self.loss_fn = self._configure_criterion()
+ loss_fn: Type[nn.Module] = attr.ib(init=False)
- # Accuracy metric
- self.train_acc = torchmetrics.Accuracy()
- self.val_acc = torchmetrics.Accuracy()
- self.test_acc = torchmetrics.Accuracy()
+ train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy())
+ @loss_fn.default
def configure_criterion(self) -> Type[nn.Module]:
"""Returns a loss functions."""
log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 8c9fe8a..f5cb491 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,13 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Union, Tuple, Type
+from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
import attr
import hydra
from omegaconf import DictConfig
from torch import nn, Tensor
-from text_recognizer.data.emnist import emnist_mapping
-from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping_config: DictConfig = attr.ib(converter=DictConfig)
+ ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+ val_cer: CharacterErrorRate = attr.ib(init=False)
+ test_cer: CharacterErrorRate = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
- self.mapping, ignore_tokens = self._configure_mapping()
- self.val_cer = CharacterErrorRate(ignore_tokens)
- self.test_cer = CharacterErrorRate(ignore_tokens)
+ self.val_cer = CharacterErrorRate(self.ignore_tokens)
+ self.test_cer = CharacterErrorRate(self.ignore_tokens)
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
return self.network.predict(data)
- @staticmethod
- def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]:
- """Configure mapping."""
- # TODO: Fix me!!!
- # Load config with hydra
- mapping, inverse_mapping, _ = emnist_mapping(["\n"])
- start_index = inverse_mapping["<s>"]
- end_index = inverse_mapping["<e>"]
- pad_index = inverse_mapping["<p>"]
- ignore_tokens = [start_index, end_index, pad_index]
- # TODO: add case for sentence pieces
- return mapping, ignore_tokens
-
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, targets = batch
diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py
new file mode 100644
index 0000000..07b6a32
--- /dev/null
+++ b/text_recognizer/networks/base.py
@@ -0,0 +1,18 @@
+"""Base network with required methods."""
+from abc import abstractmethod
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class BaseNetwork(nn.Module):
+ """Base network."""
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ @abstractmethod
+ def predict(self, x: Tensor) -> Tensor:
+ """Return token indices for predictions."""
+ ...
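
The attrs-based modules in this commit (ConvTransformer below, BaseLitModel above) rely on BaseNetwork's __attrs_pre_init__ calling nn.Module.__init__ before attrs assigns any fields. A minimal sketch of how a subclass is expected to use this pattern; the names ToyNetwork, input_dim, num_classes and head are hypothetical and not taken from the repository:

import attr
from torch import nn, Tensor

from text_recognizer.networks.base import BaseNetwork


@attr.s(auto_attribs=True)
class ToyNetwork(BaseNetwork):
    """Hypothetical subclass: attr.ib() fields, submodules built in __attrs_post_init__."""

    input_dim: int = attr.ib()
    num_classes: int = attr.ib()
    head: nn.Linear = attr.ib(init=False)

    def __attrs_post_init__(self) -> None:
        # nn.Module.__init__ has already run via the inherited __attrs_pre_init__,
        # so submodules assigned here are registered as usual.
        self.head = nn.Linear(self.input_dim, self.num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.head(x)

    def predict(self, x: Tensor) -> Tensor:
        # Fulfils the abstract predict() contract: return class/token indices.
        return self.forward(x).argmax(dim=-1)
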
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/conv_transformer.py
index ce7ec43..4acdc36 100644
--- a/text_recognizer/networks/cnn_tranformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -7,6 +7,7 @@ import torch
from torch import nn, Tensor
from text_recognizer.data.mappings import AbstractMapping
+from text_recognizer.networks.base import BaseNetwork
from text_recognizer.networks.encoders.efficientnet import EfficientNet
from text_recognizer.networks.transformer.layers import Decoder
from text_recognizer.networks.transformer.positional_encodings import (
@@ -15,39 +16,37 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s
-class Reader(nn.Module):
- def __attrs_pre_init__(self) -> None:
- super().__init__()
-
+@attr.s(auto_attribs=True)
+class ConvTransformer(BaseNetwork):
# Parameters and placeholders,
input_dims: Tuple[int, int, int] = attr.ib()
hidden_dim: int = attr.ib()
dropout_rate: float = attr.ib()
max_output_len: int = attr.ib()
num_classes: int = attr.ib()
- padding_idx: int = attr.ib()
start_token: str = attr.ib()
- start_index: int = attr.ib(init=False)
+ start_index: Tensor = attr.ib(init=False)
end_token: str = attr.ib()
- end_index: int = attr.ib(init=False)
+ end_index: Tensor = attr.ib(init=False)
pad_token: str = attr.ib()
- pad_index: int = attr.ib(init=False)
+ pad_index: Tensor = attr.ib(init=False)
# Modules.
encoder: EfficientNet = attr.ib()
decoder: Decoder = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
latent_encoder: nn.Sequential = attr.ib(init=False)
token_embedding: nn.Embedding = attr.ib(init=False)
token_pos_encoder: PositionalEncoding = attr.ib(init=False)
head: nn.Linear = attr.ib(init=False)
- mapping: Type[AbstractMapping] = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
- self.start_index = int(self.mapping.get_index(self.start_token))
- self.end_index = int(self.mapping.get_index(self.end_token))
- self.pad_index = int(self.mapping.get_index(self.pad_token))
+ self.start_index = self.mapping.get_index(self.start_token)
+ self.end_index = self.mapping.get_index(self.end_token)
+ self.pad_index = self.mapping.get_index(self.pad_token)
+
# Latent projector for down sampling number of filters and 2d
# positional encoding.
self.latent_encoder = nn.Sequential(
@@ -130,7 +129,7 @@ class Reader(nn.Module):
Returns:
Tensor: Sequence of word piece embeddings.
"""
- context_mask = context != self.padding_idx
+ context_mask = context != self.pad_index
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
out = self.decoder(x=context, context=z, mask=context_mask)
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
index e69de29..ee47c59 100644
--- a/training/conf/criterion/label_smoothing.yaml
+++ b/training/conf/criterion/label_smoothing.yaml
@@ -0,0 +1,4 @@
+_target_: text_recognizer.criterion.label_smoothing
+label_smoothing: 0.1
+vocab_size: 1006
+ignore_index: 1002
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
new file mode 100644
index 0000000..39e2ba4
--- /dev/null
+++ b/training/conf/mapping/word_piece.yaml
@@ -0,0 +1,9 @@
+_target_: text_recognizer.data.mappings.WordPieceMapping
+num_features: 1000
+tokens: iamdb_1kwp_tokens_1000.txt
+lexicon: iamdb_1kwp_lex_1000.txt
+data_dir: null
+use_words: false
+prepend_wordsep: false
+special_tokens: ["<s>", "<e>", "<p>"]
+extra_symbols: ["\n"]
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
new file mode 100644
index 0000000..4e04b85
--- /dev/null
+++ b/training/conf/model/lit_transformer.yaml
@@ -0,0 +1,4 @@
+_target_: text_recognizer.models.transformer.TransformerLitModel
+interval: null
+monitor: val/loss
+ignore_tokens: ["<s>", "<e>", "<p>"]
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
new file mode 100644
index 0000000..f72e030
--- /dev/null
+++ b/training/conf/network/conv_transformer.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - encoder: efficientnet
+ - decoder: transformer_decoder
+
+_target_: text_recognizer.networks.conv_transformer.ConvTransformer
+input_dims: [1, 576, 640]
+hidden_dim: 256
+dropout_rate: 0.2
+max_output_len: 682
+num_classes: 1004
+start_token: <s>
+end_token: <e>
+pad_token: <p>
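
Taken together, the new mapping and network configs can be composed and instantiated the same way the 05c notebook above does it. A minimal sketch, run from the notebooks/ directory and assuming Hydra 1.1-style initialize/compose with recursive instantiation:

from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf

# Compose the network config added in this commit; the encoder/decoder defaults
# are resolved from training/conf/network/.
with initialize(config_path="../training/conf/network/", job_name="test_app"):
    net_cfg = compose(config_name="conv_transformer")
    print(OmegaConf.to_yaml(net_cfg))

# Compose and build the word-piece mapping.
with initialize(config_path="../training/conf/mapping/", job_name="test_app"):
    map_cfg = compose(config_name="word_piece")
mapping = instantiate(map_cfg)

# ConvTransformer now takes the mapping instance as a field (see the diff of
# conv_transformer.py above), so it is passed as an override on top of the YAML.
network = instantiate(net_cfg, mapping=mapping)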