From 34098ccbbbf6379c0bd29a987440b8479c743746 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 29 Jul 2021 23:59:52 +0200 Subject: Configs, refactor with attrs, fix attr bug in iam --- notebooks/00-scratch-pad.ipynb | 10 +- notebooks/01-look-at-emnist.ipynb | 4 +- notebooks/02b-look-at-emnist-lines.ipynb | 8 +- notebooks/03-look-at-iam-lines.ipynb | 4 +- notebooks/03-look-at-iam-paragraphs.ipynb | 8 +- .../04b-look-at-iam-paragraphs-predictions.ipynb | 8 +- notebooks/04b-look-at-iam-paragraphs.ipynb | 8 +- notebooks/05c-test-model-end-to-end.ipynb | 340 +++++++-------------- text_recognizer/criterions/label_smoothing.py | 42 +++ text_recognizer/criterions/label_smoothing_loss.py | 42 --- text_recognizer/data/base_dataset.py | 1 + text_recognizer/data/emnist.py | 2 +- text_recognizer/data/iam_extended_paragraphs.py | 23 +- text_recognizer/data/iam_lines.py | 6 +- text_recognizer/data/iam_paragraphs.py | 7 +- text_recognizer/data/iam_synthetic_paragraphs.py | 12 +- text_recognizer/models/base.py | 31 +- text_recognizer/models/transformer.py | 26 +- text_recognizer/networks/base.py | 18 ++ text_recognizer/networks/cnn_tranformer.py | 202 ------------ text_recognizer/networks/conv_transformer.py | 201 ++++++++++++ training/conf/criterion/label_smoothing.yaml | 4 + training/conf/mapping/word_piece.yaml | 9 + training/conf/model/lit_transformer.yaml | 4 + training/conf/network/conv_transformer.yaml | 13 + 25 files changed, 467 insertions(+), 566 deletions(-) create mode 100644 text_recognizer/criterions/label_smoothing.py delete mode 100644 text_recognizer/criterions/label_smoothing_loss.py create mode 100644 text_recognizer/networks/base.py delete mode 100644 text_recognizer/networks/cnn_tranformer.py create mode 100644 text_recognizer/networks/conv_transformer.py create mode 100644 training/conf/mapping/word_piece.yaml create mode 100644 training/conf/model/lit_transformer.yaml create mode 100644 training/conf/network/conv_transformer.yaml diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index 2c98064..0350727 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -49,9 +49,7 @@ { "cell_type": "code", "execution_count": 7, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "en = EfficientNet(\"b0\")" @@ -268,9 +266,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "summary(en, (1, 224, 224));" @@ -1157,7 +1153,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, diff --git a/notebooks/01-look-at-emnist.ipynb b/notebooks/01-look-at-emnist.ipynb index 5b5310e..1ca06c5 100644 --- a/notebooks/01-look-at-emnist.ipynb +++ b/notebooks/01-look-at-emnist.ipynb @@ -106,7 +106,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -120,7 +120,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/notebooks/02b-look-at-emnist-lines.ipynb b/notebooks/02b-look-at-emnist-lines.ipynb index 93893f9..89045a4 100644 --- a/notebooks/02b-look-at-emnist-lines.ipynb +++ b/notebooks/02b-look-at-emnist-lines.ipynb @@ -136,9 +136,7 @@ { "cell_type": "code", "execution_count": 9, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -270,7 +268,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -284,7 +282,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/notebooks/03-look-at-iam-lines.ipynb b/notebooks/03-look-at-iam-lines.ipynb index ab12642..383381d 100644 --- a/notebooks/03-look-at-iam-lines.ipynb +++ b/notebooks/03-look-at-iam-lines.ipynb @@ -228,7 +228,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -242,7 +242,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 315b7bf..dd3a934 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -317,9 +317,7 @@ "cell_type": "code", "execution_count": 61, "id": "e7778ae2", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -507,7 +505,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -521,7 +519,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb index 5662eb1..40d371c 100644 --- a/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb +++ b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb @@ -99,9 +99,7 @@ { "cell_type": "code", "execution_count": 39, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -247,7 +245,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -261,7 +259,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.2" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/notebooks/04b-look-at-iam-paragraphs.ipynb b/notebooks/04b-look-at-iam-paragraphs.ipynb index 11ebddf..414ea85 100644 --- a/notebooks/04b-look-at-iam-paragraphs.ipynb +++ b/notebooks/04b-look-at-iam-paragraphs.ipynb @@ -97,9 +97,7 @@ { "cell_type": "code", "execution_count": 48, - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [ { "data": { @@ -242,7 +240,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -256,7 +254,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.1" + "version": "3.9.6" } }, "nbformat": 4, diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index a0b4ee9..e2ccb3c 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -19,43 +19,13 @@ "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", - " sys.path.append('..')" + " sys.path.append('..')\n", + " " ] }, { "cell_type": "code", "execution_count": 2, - "id": "2ab9ac7a-a288-45bc-bfb7-8579a3a38d93", - "metadata": {}, - "outputs": [], - "source": [ - "import torch.nn.functional as F" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "ecab65ba-5aa0-45f0-99d7-e837464185ac", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - " torch.Tensor>" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "F.softmax" - ] - }, - { - "cell_type": "code", - "execution_count": 5, "id": "3e812a1e", "metadata": {}, "outputs": [], @@ -65,309 +35,231 @@ }, { "cell_type": "code", - "execution_count": 10, - "id": "a42a7988", + "execution_count": 3, + "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0", "metadata": {}, "outputs": [], "source": [ - "@attr.s\n", - "class C(object):\n", - " d = {2: \"hej\"}\n", - " x: F.softmax = attr.ib(init=False, default=F.softmax)\n", - " @x.validator\n", - " def check(self, attribute, value):\n", - " print(attribute)\n", - " print(self.x)" + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf\n", + "from hydra.utils import instantiate" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "660a7b1f", + "execution_count": 4, + "id": "9c797159-845e-42c6-bd65-1c976ad627cd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Attribute(name='x', default=, validator=, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=False, metadata=mappingproxy({}), type=, converter=None, kw_only=False, inherited=False, on_setattr=None)\n", - "\n" + "encoder:\n", + " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n", + " arch: b0\n", + " out_channels: 1280\n", + " stochastic_dropout_rate: 0.2\n", + " bn_momentum: 0.99\n", + " bn_eps: 0.001\n", + "decoder:\n", + " _target_: text_recognizer.networks.transformer.Decoder\n", + " dim: 256\n", + " depth: 2\n", + " num_heads: 8\n", + " attn_fn: text_recognizer.networks.transformer.attention.Attention\n", + " attn_kwargs:\n", + " num_heads: 8\n", + " dim_head: 64\n", + " dropout_rate: 0.2\n", + " norm_fn: torch.nn.LayerNorm\n", + " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n", + " ff_kwargs:\n", + " dim: 256\n", + " dim_out: null\n", + " expansion_factor: 4\n", + " glu: true\n", + " dropout_rate: 0.2\n", + " rotary_emb: null\n", + " rotary_emb_dim: null\n", + " cross_attend: true\n", + " pre_norm: true\n", + "_target_: text_recognizer.networks.conv_transformer.ConvTransformer\n", + "input_dims:\n", + "- 1\n", + "- 576\n", + "- 640\n", + "hidden_dim: 256\n", + "dropout_rate: 0.2\n", + "max_output_len: 682\n", + "num_classes: 1004\n", + "start_token: \n", + "end_token: \n", + "pad_token:

\n", + "\n", + "{'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 256, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'num_heads': 8, 'dim_head': 64, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim': 256, 'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'rotary_emb': None, 'rotary_emb_dim': None, 'cross_attend': True, 'pre_norm': True}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 576, 640], 'hidden_dim': 256, 'dropout_rate': 0.2, 'max_output_len': 682, 'num_classes': 1004, 'start_token': '', 'end_token': '', 'pad_token': '

'}\n" ] } ], "source": [ - "c = C()" + "# context initialization\n", + "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"conv_transformer\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "9c3d1163", + "execution_count": 5, + "id": "cdb895b6-8949-4318-8a40-06fb5ed5e8d6", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - " torch.Tensor>" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "_target_: text_recognizer.data.mappings.WordPieceMapping\n", + "num_features: 1000\n", + "tokens: iamdb_1kwp_tokens_1000.txt\n", + "lexicon: iamdb_1kwp_lex_1000.txt\n", + "data_dir: null\n", + "use_words: false\n", + "prepend_wordsep: false\n", + "special_tokens:\n", + "- \n", + "- \n", + "-

\n", + "extra_symbols:\n", + "- '\n", + "\n", + " '\n", + "\n", + "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['', '', '

'], 'extra_symbols': ['\\n']}\n" + ] } ], "source": [ - "c.x" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "b3c8879c", - "metadata": {}, - "outputs": [], - "source": [ - "from torch import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "2f5f6b75", - "metadata": {}, - "outputs": [], - "source": [ - "l = nn.ModuleList([])" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "9938ec53", - "metadata": {}, - "outputs": [], - "source": [ - "f = nn.Linear(10, 10)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "fc49db78", - "metadata": {}, - "outputs": [], - "source": [ - "for _ in range(10):\n", - " l.append(f)" + "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"word_piece\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" ] }, { "cell_type": "code", - "execution_count": 36, - "id": "e799a9dc", + "execution_count": 6, + "id": "b6181656-580a-4d96-8495-b6bb510944cc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " (1): Linear(in_features=10, out_features=10, bias=True)\n", - " (2): Linear(in_features=10, out_features=10, bias=True)\n", - " (3): Linear(in_features=10, out_features=10, bias=True)\n", - " (4): Linear(in_features=10, out_features=10, bias=True)\n", - " (5): Linear(in_features=10, out_features=10, bias=True)\n", - " (6): Linear(in_features=10, out_features=10, bias=True)\n", - " (7): Linear(in_features=10, out_features=10, bias=True)\n", - " (8): Linear(in_features=10, out_features=10, bias=True)\n", - " (9): Linear(in_features=10, out_features=10, bias=True)\n", - ")" + "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['', '', '

'], 'extra_symbols': ['\\n']}" ] }, - "execution_count": 36, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "\n", - "l" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "17213dfb", - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'Linear' object has no attribute 'copy'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_31696/2302067867.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mff\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 1131\u001b[0m type(self).__name__, name))\n\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'Linear' object has no attribute 'copy'" - ] - } - ], - "source": [ - "ff = f.copy()" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "60277c26", - "metadata": {}, - "outputs": [], - "source": [ - "from copy import deepcopy" + "cfg" ] }, { "cell_type": "code", - "execution_count": 39, - "id": "cf86534a", + "execution_count": null, + "id": "5cd80d84-3ae5-4bb4-bc00-0dac7b22e134", "metadata": {}, "outputs": [], - "source": [ - "ff = deepcopy(f)" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "2a260dc8", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "140011688939472" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "id(ff)" - ] + "source": [] }, { "cell_type": "code", - "execution_count": 42, - "id": "6dcf5f63", + "execution_count": 8, + "id": "0c123c76-ed90-49fa-903b-70ad60a33f16", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "140011688936544" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-07-29 23:02:56.650 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" + ] } ], "source": [ - "id(f)" + "mapping = instantiate(cfg)" ] }, { "cell_type": "code", - "execution_count": 44, - "id": "74958f8d", + "execution_count": 9, + "id": "ff6c57f0-3c96-418e-8192-cd12bf79c073", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "140011688936544" + "tensor([1002])" ] }, - "execution_count": 44, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "id(l[0])" + "mapping.get_index(\"

\")" ] }, { "cell_type": "code", - "execution_count": 45, - "id": "bcceabd5", + "execution_count": 10, + "id": "348391ec-0cf7-49f6-bac2-26bc8c966705", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "140011688936544" + "1006" ] }, - "execution_count": 45, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "id(l[1])" + "len(mapping)" ] }, { "cell_type": "code", - "execution_count": 58, - "id": "191a0b03", + "execution_count": 15, + "id": "67673bf2-79c6-4010-93dd-9c9ba8f9a90e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'nn'" + "tensor([1003])" ] }, - "execution_count": 58, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "\".\".join(\"nn.LayerNorm\".split(\".\")[:-1])" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "4ff8ae08", - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'str' object has no attribute 'LayerNorm'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_31696/162121485.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"torch.nn\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"LayerNorm\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m: 'str' object has no attribute 'LayerNorm'" - ] - } - ], - "source": [ - "getattr(\"torch.nn\", \"LayerNorm\")" + "mapping.get_index(\"\\n\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "4d536bf2", + "id": "8923ea1e-b571-42ee-bfd7-4984aa70644f", "metadata": {}, "outputs": [], "source": [] diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py new file mode 100644 index 0000000..40a7609 --- /dev/null +++ b/text_recognizer/criterions/label_smoothing.py @@ -0,0 +1,42 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class LabelSmoothingLoss(nn.Module): + """Label smoothing cross entropy loss.""" + + def __init__( + self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 + ) -> None: + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super().__init__() + + smoothing_value = label_smoothing / (vocab_size - 2) + one_hot = torch.full((vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer("one_hot", one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): Predictions from the network. + targets (Tensor): Ground truth. + + Shapes: + outpus: Batch size x num classes + targets: Batch size + + Returns: + Tensor: Label smoothing loss. + """ + model_prob = self.one_hot.repeat(targets.size(0), 1) + model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) + model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) + return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/criterions/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing_loss.py deleted file mode 100644 index 40a7609..0000000 --- a/text_recognizer/criterions/label_smoothing_loss.py +++ /dev/null @@ -1,42 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor -import torch.nn.functional as F - - -class LabelSmoothingLoss(nn.Module): - """Label smoothing cross entropy loss.""" - - def __init__( - self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 - ) -> None: - assert 0.0 < label_smoothing <= 1.0 - self.ignore_index = ignore_index - super().__init__() - - smoothing_value = label_smoothing / (vocab_size - 2) - one_hot = torch.full((vocab_size,), smoothing_value) - one_hot[self.ignore_index] = 0 - self.register_buffer("one_hot", one_hot.unsqueeze(0)) - - self.confidence = 1.0 - label_smoothing - - def forward(self, output: Tensor, targets: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): Predictions from the network. - targets (Tensor): Ground truth. - - Shapes: - outpus: Batch size x num classes - targets: Batch size - - Returns: - Tensor: Label smoothing loss. - """ - model_prob = self.one_hot.repeat(targets.size(0), 1) - model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) - model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) - return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index 4318dfb..c26f1c9 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -29,6 +29,7 @@ class BaseDataset(Dataset): super().__init__() def __attrs_post_init__(self) -> None: + # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index d51a42a..2d0ac29 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -46,7 +46,7 @@ class EMNIST(BaseDataModule): EMNIST ByClass: 814,255 characters. 62 unbalanced classes. """ - train_fraction: float = attr.ib() + train_fraction: float = attr.ib(default=0.8) transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) def __attrs_post_init__(self) -> None: diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 886e37e..58c7369 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -13,23 +13,24 @@ from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs @attr.s(auto_attribs=True) class IAMExtendedParagraphs(BaseDataModule): - train_fraction: float = attr.ib() + augment: bool = attr.ib(default=True) + train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( - self.batch_size, - self.num_workers, - self.train_fraction, - self.augment, - self.word_pieces, + batch_size=self.batch_size, + num_workers=self.num_workers, + train_fraction=self.train_fraction, + augment=self.augment, + word_pieces=self.word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( - self.batch_size, - self.num_workers, - self.train_fraction, - self.augment, - self.word_pieces, + batch_size=self.batch_size, + num_workers=self.num_workers, + train_fraction=self.train_fraction, + augment=self.augment, + word_pieces=self.word_pieces, ) self.dims = self.iam_paragraphs.dims diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index e45e5c8..705cfa3 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -34,6 +34,7 @@ SEED = 4711 PROCESSED_DATA_DIRNAME = BaseDataModule.data_dirname() / "processed" / "iam_lines" IMAGE_HEIGHT = 56 IMAGE_WIDTH = 1024 +MAX_LABEL_LENGTH = 89 @attr.s(auto_attribs=True) @@ -42,11 +43,12 @@ class IAMLines(BaseDataModule): augment: bool = attr.ib(default=True) fraction: float = attr.ib(default=0.8) + dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)) + output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) def __attrs_post_init__(self) -> None: + # TODO: refactor this self.mapping, self.inverse_mapping, _ = emnist_mapping() - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (89, 1) def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index bdfb490..9977978 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -41,6 +41,8 @@ class IAMParagraphs(BaseDataModule): augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) word_pieces: bool = attr.ib(default=False) + dims: Tuple[int, int, int] = attr.ib(init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)) + output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) def __attrs_post_init__(self) -> None: self.mapping, self.inverse_mapping, _ = emnist_mapping( @@ -49,11 +51,6 @@ class IAMParagraphs(BaseDataModule): if self.word_pieces: self.mapping = WordPieceMapping() - self.train_fraction = train_fraction - - self.dims = (1, IMAGE_HEIGHT, IMAGE_WIDTH) - self.output_dims = (MAX_LABEL_LENGTH, 1) - def prepare_data(self) -> None: """Create data for training/testing.""" if PROCESSED_DATA_DIRNAME.exists(): diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index 00fa2b6..a3697e7 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -2,6 +2,7 @@ import random from typing import Any, List, Sequence, Tuple +import attr from loguru import logger import numpy as np from PIL import Image @@ -33,19 +34,10 @@ PROCESSED_DATA_DIRNAME = ( ) +@attr.s(auto_attribs=True) class IAMSyntheticParagraphs(IAMParagraphs): """IAM Handwriting database of synthetic paragraphs.""" - def __init__( - self, - batch_size: int = 16, - num_workers: int = 0, - train_fraction: float = 0.8, - augment: bool = True, - word_pieces: bool = False, - ) -> None: - super().__init__(batch_size, num_workers, train_fraction, augment, word_pieces) - def prepare_data(self) -> None: """Prepare IAM lines to be used to generate paragraphs.""" if PROCESSED_DATA_DIRNAME.exists(): diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index f95df0f..3b83056 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -3,20 +3,25 @@ from typing import Any, Dict, List, Tuple, Type import attr import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig -import pytorch_lightning as LightningModule +from pytorch_lightning import LightningModule import torch from torch import nn from torch import Tensor import torchmetrics +from text_recognizer.networks.base import BaseNetwork + @attr.s class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" - network: Type[nn.Module] = attr.ib() + def __attrs_pre_init__(self) -> None: + super().__init__() + + network: Type[BaseNetwork] = attr.ib() criterion_config: DictConfig = attr.ib(converter=DictConfig) optimizer_config: DictConfig = attr.ib(converter=DictConfig) lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) @@ -24,23 +29,13 @@ class BaseLitModel(LightningModule): interval: str = attr.ib() monitor: str = attr.ib(default="val/loss") - loss_fn = attr.ib(init=False) - - train_acc = attr.ib(init=False) - val_acc = attr.ib(init=False) - test_acc = attr.ib(init=False) - - def __attrs_pre_init__(self) -> None: - super().__init__() - - def __attrs_post_init__(self) -> None: - self.loss_fn = self._configure_criterion() + loss_fn: Type[nn.Module] = attr.ib(init=False) - # Accuracy metric - self.train_acc = torchmetrics.Accuracy() - self.val_acc = torchmetrics.Accuracy() - self.test_acc = torchmetrics.Accuracy() + train_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + val_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + test_acc: torchmetrics.Accuracy = attr.ib(init=False, default=torchmetrics.Accuracy()) + @loss_fn.default def configure_criterion(self) -> Type[nn.Module]: """Returns a loss functions.""" log.info(f"Instantiating criterion <{self.criterion_config._target_}>") diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 8c9fe8a..f5cb491 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,13 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Union, Tuple, Type +from typing import Dict, List, Optional, Sequence, Union, Tuple, Type import attr import hydra from omegaconf import DictConfig from torch import nn, Tensor -from text_recognizer.data.emnist import emnist_mapping -from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel @@ -16,30 +14,18 @@ from text_recognizer.models.base import BaseLitModel class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping_config: DictConfig = attr.ib(converter=DictConfig) + ignore_tokens: Sequence[str] = attr.ib(default=("", "", "

",)) + val_cer: CharacterErrorRate = attr.ib(init=False) + test_cer: CharacterErrorRate = attr.ib(init=False) def __attrs_post_init__(self) -> None: - self.mapping, ignore_tokens = self._configure_mapping() - self.val_cer = CharacterErrorRate(ignore_tokens) - self.test_cer = CharacterErrorRate(ignore_tokens) + self.val_cer = CharacterErrorRate(self.ignore_tokens) + self.test_cer = CharacterErrorRate(self.ignore_tokens) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" return self.network.predict(data) - @staticmethod - def _configure_mapping() -> Tuple[Type[AbstractMapping], List[int]]: - """Configure mapping.""" - # TODO: Fix me!!! - # Load config with hydra - mapping, inverse_mapping, _ = emnist_mapping(["\n"]) - start_index = inverse_mapping[""] - end_index = inverse_mapping[""] - pad_index = inverse_mapping["

"] - ignore_tokens = [start_index, end_index, pad_index] - # TODO: add case for sentence pieces - return mapping, ignore_tokens - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, targets = batch diff --git a/text_recognizer/networks/base.py b/text_recognizer/networks/base.py new file mode 100644 index 0000000..07b6a32 --- /dev/null +++ b/text_recognizer/networks/base.py @@ -0,0 +1,18 @@ +"""Base network with required methods.""" +from abc import abstractmethod + +import attr +from torch import nn, Tensor + + +@attr.s +class BaseNetwork(nn.Module): + """Base network.""" + + def __attrs_pre_init__(self) -> None: + super().__init__() + + @abstractmethod + def predict(self, x: Tensor) -> Tensor: + """Return token indices for predictions.""" + ... diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py deleted file mode 100644 index ce7ec43..0000000 --- a/text_recognizer/networks/cnn_tranformer.py +++ /dev/null @@ -1,202 +0,0 @@ -"""Vision transformer for character recognition.""" -import math -from typing import Tuple, Type - -import attr -import torch -from torch import nn, Tensor - -from text_recognizer.data.mappings import AbstractMapping -from text_recognizer.networks.encoders.efficientnet import EfficientNet -from text_recognizer.networks.transformer.layers import Decoder -from text_recognizer.networks.transformer.positional_encodings import ( - PositionalEncoding, - PositionalEncoding2D, -) - - -@attr.s -class Reader(nn.Module): - def __attrs_pre_init__(self) -> None: - super().__init__() - - # Parameters and placeholders, - input_dims: Tuple[int, int, int] = attr.ib() - hidden_dim: int = attr.ib() - dropout_rate: float = attr.ib() - max_output_len: int = attr.ib() - num_classes: int = attr.ib() - padding_idx: int = attr.ib() - start_token: str = attr.ib() - start_index: int = attr.ib(init=False) - end_token: str = attr.ib() - end_index: int = attr.ib(init=False) - pad_token: str = attr.ib() - pad_index: int = attr.ib(init=False) - - # Modules. - encoder: EfficientNet = attr.ib() - decoder: Decoder = attr.ib() - latent_encoder: nn.Sequential = attr.ib(init=False) - token_embedding: nn.Embedding = attr.ib(init=False) - token_pos_encoder: PositionalEncoding = attr.ib(init=False) - head: nn.Linear = attr.ib(init=False) - mapping: Type[AbstractMapping] = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - """Post init configuration.""" - self.start_index = int(self.mapping.get_index(self.start_token)) - self.end_index = int(self.mapping.get_index(self.end_token)) - self.pad_index = int(self.mapping.get_index(self.pad_token)) - # Latent projector for down sampling number of filters and 2d - # positional encoding. - self.latent_encoder = nn.Sequential( - nn.Conv2d( - in_channels=self.encoder.out_channels, - out_channels=self.hidden_dim, - kernel_size=1, - ), - PositionalEncoding2D( - hidden_dim=self.hidden_dim, - max_h=self.input_dims[1], - max_w=self.input_dims[2], - ), - nn.Flatten(start_dim=2), - ) - - # Token embedding. - self.token_embedding = nn.Embedding( - num_embeddings=self.num_classes, embedding_dim=self.hidden_dim - ) - - # Positional encoding for decoder tokens. - self.token_pos_encoder = PositionalEncoding( - hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate - ) - # Head - self.head = nn.Linear( - in_features=self.hidden_dim, out_features=self.num_classes - ) - - # Initalize weights for encoder. - self.init_weights() - - def init_weights(self) -> None: - """Initalize weights for decoder network and head.""" - bound = 0.1 - self.token_embedding.weight.data.uniform_(-bound, bound) - self.head.bias.data.zero_() - self.head.weight.data.uniform_(-bound, bound) - # TODO: Initalize encoder? - - def encode(self, x: Tensor) -> Tensor: - """Encodes an image into a latent feature vector. - - Args: - x (Tensor): Image tensor. - - Shape: - - x: :math: `(B, C, H, W)` - - z: :math: `(B, Sx, E)` - - where Sx is the length of the flattened feature maps projected from - the encoder. E latent dimension for each pixel in the projected - feature maps. - - Returns: - Tensor: A Latent embedding of the image. - """ - z = self.encoder(x) - z = self.latent_encoder(z) - - # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] - z = z.permute(0, 2, 1) - return z - - def decode(self, z: Tensor, context: Tensor) -> Tensor: - """Decodes latent images embedding into word pieces. - - Args: - z (Tensor): Latent images embedding. - context (Tensor): Word embeddings. - - Shapes: - - z: :math: `(B, Sx, E)` - - context: :math: `(B, Sy)` - - out: :math: `(B, Sy, T)` - - where Sy is the length of the output and T is the number of tokens. - - Returns: - Tensor: Sequence of word piece embeddings. - """ - context_mask = context != self.padding_idx - context = self.token_embedding(context) * math.sqrt(self.hidden_dim) - context = self.token_pos_encoder(context) - out = self.decoder(x=context, context=z, mask=context_mask) - logits = self.head(out) - return logits - - def forward(self, x: Tensor, context: Tensor) -> Tensor: - """Encodes images into word piece logtis. - - Args: - x (Tensor): Input image(s). - context (Tensor): Target word embeddings. - - Shapes: - - x: :math: `(B, C, H, W)` - - context: :math: `(B, Sy, T)` - - where B is the batch size, C is the number of input channels, H is - the image height and W is the image width. - - Returns: - Tensor: Sequence of logits. - """ - z = self.encode(x) - logits = self.decode(z, context) - return logits - - def predict(self, x: Tensor) -> Tensor: - """Predicts text in image. - - Args: - x (Tensor): Image(s) to extract text from. - - Shapes: - - x: :math: `(B, H, W)` - - output: :math: `(B, S)` - - Returns: - Tensor: A tensor of token indices of the predictions from the model. - """ - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - z = self.encode(x) - - # Create a placeholder matrix for storing outputs from the network - output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - output[:, 0] = self.start_index - - for i in range(1, self.max_output_len): - context = output[:, :i] # (bsz, i) - logits = self.decode(z, context) # (i, bsz, c) - tokens = torch.argmax(logits, dim=-1) # (i, bsz) - output[:, i : i + 1] = tokens[-1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index - ).all(): - break - - # Set all tokens after end token to pad token. - for i in range(1, self.max_output_len): - idx = ( - output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index - ) - output[idx, i] = self.pad_index - - return output diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py new file mode 100644 index 0000000..4acdc36 --- /dev/null +++ b/text_recognizer/networks/conv_transformer.py @@ -0,0 +1,201 @@ +"""Vision transformer for character recognition.""" +import math +from typing import Tuple, Type + +import attr +import torch +from torch import nn, Tensor + +from text_recognizer.data.mappings import AbstractMapping +from text_recognizer.networks.base import BaseNetwork +from text_recognizer.networks.encoders.efficientnet import EfficientNet +from text_recognizer.networks.transformer.layers import Decoder +from text_recognizer.networks.transformer.positional_encodings import ( + PositionalEncoding, + PositionalEncoding2D, +) + + +@attr.s(auto_attribs=True) +class ConvTransformer(BaseNetwork): + # Parameters and placeholders, + input_dims: Tuple[int, int, int] = attr.ib() + hidden_dim: int = attr.ib() + dropout_rate: float = attr.ib() + max_output_len: int = attr.ib() + num_classes: int = attr.ib() + start_token: str = attr.ib() + start_index: Tensor = attr.ib(init=False) + end_token: str = attr.ib() + end_index: Tensor = attr.ib(init=False) + pad_token: str = attr.ib() + pad_index: Tensor = attr.ib(init=False) + + # Modules. + encoder: EfficientNet = attr.ib() + decoder: Decoder = attr.ib() + mapping: Type[AbstractMapping] = attr.ib() + + latent_encoder: nn.Sequential = attr.ib(init=False) + token_embedding: nn.Embedding = attr.ib(init=False) + token_pos_encoder: PositionalEncoding = attr.ib(init=False) + head: nn.Linear = attr.ib(init=False) + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.start_index = self.mapping.get_index(self.start_token) + self.end_index = self.mapping.get_index(self.end_token) + self.pad_index = self.mapping.get_index(self.pad_token) + + # Latent projector for down sampling number of filters and 2d + # positional encoding. + self.latent_encoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, + ), + PositionalEncoding2D( + hidden_dim=self.hidden_dim, + max_h=self.input_dims[1], + max_w=self.input_dims[2], + ), + nn.Flatten(start_dim=2), + ) + + # Token embedding. + self.token_embedding = nn.Embedding( + num_embeddings=self.num_classes, embedding_dim=self.hidden_dim + ) + + # Positional encoding for decoder tokens. + self.token_pos_encoder = PositionalEncoding( + hidden_dim=self.hidden_dim, dropout_rate=self.dropout_rate + ) + # Head + self.head = nn.Linear( + in_features=self.hidden_dim, out_features=self.num_classes + ) + + # Initalize weights for encoder. + self.init_weights() + + def init_weights(self) -> None: + """Initalize weights for decoder network and head.""" + bound = 0.1 + self.token_embedding.weight.data.uniform_(-bound, bound) + self.head.bias.data.zero_() + self.head.weight.data.uniform_(-bound, bound) + # TODO: Initalize encoder? + + def encode(self, x: Tensor) -> Tensor: + """Encodes an image into a latent feature vector. + + Args: + x (Tensor): Image tensor. + + Shape: + - x: :math: `(B, C, H, W)` + - z: :math: `(B, Sx, E)` + + where Sx is the length of the flattened feature maps projected from + the encoder. E latent dimension for each pixel in the projected + feature maps. + + Returns: + Tensor: A Latent embedding of the image. + """ + z = self.encoder(x) + z = self.latent_encoder(z) + + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z + + def decode(self, z: Tensor, context: Tensor) -> Tensor: + """Decodes latent images embedding into word pieces. + + Args: + z (Tensor): Latent images embedding. + context (Tensor): Word embeddings. + + Shapes: + - z: :math: `(B, Sx, E)` + - context: :math: `(B, Sy)` + - out: :math: `(B, Sy, T)` + + where Sy is the length of the output and T is the number of tokens. + + Returns: + Tensor: Sequence of word piece embeddings. + """ + context_mask = context != self.pad_index + context = self.token_embedding(context) * math.sqrt(self.hidden_dim) + context = self.token_pos_encoder(context) + out = self.decoder(x=context, context=z, mask=context_mask) + logits = self.head(out) + return logits + + def forward(self, x: Tensor, context: Tensor) -> Tensor: + """Encodes images into word piece logtis. + + Args: + x (Tensor): Input image(s). + context (Tensor): Target word embeddings. + + Shapes: + - x: :math: `(B, C, H, W)` + - context: :math: `(B, Sy, T)` + + where B is the batch size, C is the number of input channels, H is + the image height and W is the image width. + + Returns: + Tensor: Sequence of logits. + """ + z = self.encode(x) + logits = self.decode(z, context) + return logits + + def predict(self, x: Tensor) -> Tensor: + """Predicts text in image. + + Args: + x (Tensor): Image(s) to extract text from. + + Shapes: + - x: :math: `(B, H, W)` + - output: :math: `(B, S)` + + Returns: + Tensor: A tensor of token indices of the predictions from the model. + """ + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + z = self.encode(x) + + # Create a placeholder matrix for storing outputs from the network + output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + output[:, 0] = self.start_index + + for i in range(1, self.max_output_len): + context = output[:, :i] # (bsz, i) + logits = self.decode(z, context) # (i, bsz, c) + tokens = torch.argmax(logits, dim=-1) # (i, bsz) + output[:, i : i + 1] = tokens[-1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index + ).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = ( + output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index + ) + output[idx, i] = self.pad_index + + return output diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml index e69de29..ee47c59 100644 --- a/training/conf/criterion/label_smoothing.yaml +++ b/training/conf/criterion/label_smoothing.yaml @@ -0,0 +1,4 @@ +_target_: text_recognizer.criterion.label_smoothing +label_smoothing: 0.1 +vocab_size: 1006 +ignore_index: 1002 diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml new file mode 100644 index 0000000..39e2ba4 --- /dev/null +++ b/training/conf/mapping/word_piece.yaml @@ -0,0 +1,9 @@ +_target_: text_recognizer.data.mappings.WordPieceMapping +num_features: 1000 +tokens: iamdb_1kwp_tokens_1000.txt +lexicon: iamdb_1kwp_lex_1000.txt +data_dir: null +use_words: false +prepend_wordsep: false +special_tokens: ["", "", "

"] +extra_symbols: ["\n"] diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml new file mode 100644 index 0000000..4e04b85 --- /dev/null +++ b/training/conf/model/lit_transformer.yaml @@ -0,0 +1,4 @@ +_target_: text_recognizer.models.transformer.TransformerLitModel +interval: null +monitor: val/loss +ignore_tokens: ["", "", "

"] diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml new file mode 100644 index 0000000..f72e030 --- /dev/null +++ b/training/conf/network/conv_transformer.yaml @@ -0,0 +1,13 @@ +defaults: + - encoder: efficientnet + - decoder: transformer_decoder + +_target_: text_recognizer.networks.conv_transformer.ConvTransformer +input_dims: [1, 576, 640] +hidden_dim: 256 +dropout_rate: 0.2 +max_output_len: 682 +num_classes: 1004 +start_token: +end_token: +pad_token:

-- cgit v1.2.3-70-g09d2