diff options
Diffstat (limited to 'notebooks/05c-test-model-end-to-end.ipynb')
-rw-r--r-- | notebooks/05c-test-model-end-to-end.ipynb | 340 |
1 files changed, 116 insertions, 224 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index a0b4ee9..e2ccb3c 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -19,43 +19,13 @@ "from importlib.util import find_spec\n", "if find_spec(\"text_recognizer\") is None:\n", " import sys\n", - " sys.path.append('..')" + " sys.path.append('..')\n", + " " ] }, { "cell_type": "code", "execution_count": 2, - "id": "2ab9ac7a-a288-45bc-bfb7-8579a3a38d93", - "metadata": {}, - "outputs": [], - "source": [ - "import torch.nn.functional as F" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "ecab65ba-5aa0-45f0-99d7-e837464185ac", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "<function torch.nn.functional.softmax(input: torch.Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> torch.Tensor>" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "F.softmax" - ] - }, - { - "cell_type": "code", - "execution_count": 5, "id": "3e812a1e", "metadata": {}, "outputs": [], @@ -65,309 +35,231 @@ }, { "cell_type": "code", - "execution_count": 10, - "id": "a42a7988", + "execution_count": 3, + "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0", "metadata": {}, "outputs": [], "source": [ - "@attr.s\n", - "class C(object):\n", - " d = {2: \"hej\"}\n", - " x: F.softmax = attr.ib(init=False, default=F.softmax)\n", - " @x.validator\n", - " def check(self, attribute, value):\n", - " print(attribute)\n", - " print(self.x)" + "from hydra import compose, initialize\n", + "from omegaconf import OmegaConf\n", + "from hydra.utils import instantiate" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "660a7b1f", + "execution_count": 4, + "id": "9c797159-845e-42c6-bd65-1c976ad627cd", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Attribute(name='x', default=<function softmax at 0x7fb624839ca0>, validator=<function C.check at 0x7fb622ce2040>, repr=True, eq=True, eq_key=None, order=True, order_key=None, hash=None, init=False, metadata=mappingproxy({}), type=<function softmax at 0x7fb624839ca0>, converter=None, kw_only=False, inherited=False, on_setattr=None)\n", - "<function softmax at 0x7fb624839ca0>\n" + "encoder:\n", + " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n", + " arch: b0\n", + " out_channels: 1280\n", + " stochastic_dropout_rate: 0.2\n", + " bn_momentum: 0.99\n", + " bn_eps: 0.001\n", + "decoder:\n", + " _target_: text_recognizer.networks.transformer.Decoder\n", + " dim: 256\n", + " depth: 2\n", + " num_heads: 8\n", + " attn_fn: text_recognizer.networks.transformer.attention.Attention\n", + " attn_kwargs:\n", + " num_heads: 8\n", + " dim_head: 64\n", + " dropout_rate: 0.2\n", + " norm_fn: torch.nn.LayerNorm\n", + " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n", + " ff_kwargs:\n", + " dim: 256\n", + " dim_out: null\n", + " expansion_factor: 4\n", + " glu: true\n", + " dropout_rate: 0.2\n", + " rotary_emb: null\n", + " rotary_emb_dim: null\n", + " cross_attend: true\n", + " pre_norm: true\n", + "_target_: text_recognizer.networks.conv_transformer.ConvTransformer\n", + "input_dims:\n", + "- 1\n", + "- 576\n", + "- 640\n", + "hidden_dim: 256\n", + "dropout_rate: 0.2\n", + "max_output_len: 682\n", + "num_classes: 1004\n", + "start_token: <s>\n", + "end_token: <e>\n", + "pad_token: <p>\n", + "\n", + "{'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 256, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'num_heads': 8, 'dim_head': 64, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim': 256, 'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'rotary_emb': None, 'rotary_emb_dim': None, 'cross_attend': True, 'pre_norm': True}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 576, 640], 'hidden_dim': 256, 'dropout_rate': 0.2, 'max_output_len': 682, 'num_classes': 1004, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n" ] } ], "source": [ - "c = C()" + "# context initialization\n", + "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"conv_transformer\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "9c3d1163", + "execution_count": 5, + "id": "cdb895b6-8949-4318-8a40-06fb5ed5e8d6", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "<function torch.nn.functional.softmax(input: torch.Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[int] = None) -> torch.Tensor>" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "_target_: text_recognizer.data.mappings.WordPieceMapping\n", + "num_features: 1000\n", + "tokens: iamdb_1kwp_tokens_1000.txt\n", + "lexicon: iamdb_1kwp_lex_1000.txt\n", + "data_dir: null\n", + "use_words: false\n", + "prepend_wordsep: false\n", + "special_tokens:\n", + "- <s>\n", + "- <e>\n", + "- <p>\n", + "extra_symbols:\n", + "- '\n", + "\n", + " '\n", + "\n", + "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}\n" + ] } ], "source": [ - "c.x" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "id": "b3c8879c", - "metadata": {}, - "outputs": [], - "source": [ - "from torch import nn" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "2f5f6b75", - "metadata": {}, - "outputs": [], - "source": [ - "l = nn.ModuleList([])" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "9938ec53", - "metadata": {}, - "outputs": [], - "source": [ - "f = nn.Linear(10, 10)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "fc49db78", - "metadata": {}, - "outputs": [], - "source": [ - "for _ in range(10):\n", - " l.append(f)" + "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"word_piece\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" ] }, { "cell_type": "code", - "execution_count": 36, - "id": "e799a9dc", + "execution_count": 6, + "id": "b6181656-580a-4d96-8495-b6bb510944cc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "ModuleList(\n", - " (0): Linear(in_features=10, out_features=10, bias=True)\n", - " (1): Linear(in_features=10, out_features=10, bias=True)\n", - " (2): Linear(in_features=10, out_features=10, bias=True)\n", - " (3): Linear(in_features=10, out_features=10, bias=True)\n", - " (4): Linear(in_features=10, out_features=10, bias=True)\n", - " (5): Linear(in_features=10, out_features=10, bias=True)\n", - " (6): Linear(in_features=10, out_features=10, bias=True)\n", - " (7): Linear(in_features=10, out_features=10, bias=True)\n", - " (8): Linear(in_features=10, out_features=10, bias=True)\n", - " (9): Linear(in_features=10, out_features=10, bias=True)\n", - ")" + "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}" ] }, - "execution_count": 36, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "\n", - "l" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "17213dfb", - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'Linear' object has no attribute 'copy'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_31696/2302067867.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mff\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m~/.cache/pypoetry/virtualenvs/text-recognizer-ejNaVa9M-py3.9/lib/python3.9/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__getattr__\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 1128\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1129\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmodules\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1130\u001b[0;31m raise AttributeError(\"'{}' object has no attribute '{}'\".format(\n\u001b[0m\u001b[1;32m 1131\u001b[0m type(self).__name__, name))\n\u001b[1;32m 1132\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAttributeError\u001b[0m: 'Linear' object has no attribute 'copy'" - ] - } - ], - "source": [ - "ff = f.copy()" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "60277c26", - "metadata": {}, - "outputs": [], - "source": [ - "from copy import deepcopy" + "cfg" ] }, { "cell_type": "code", - "execution_count": 39, - "id": "cf86534a", + "execution_count": null, + "id": "5cd80d84-3ae5-4bb4-bc00-0dac7b22e134", "metadata": {}, "outputs": [], - "source": [ - "ff = deepcopy(f)" - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "2a260dc8", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "140011688939472" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "id(ff)" - ] + "source": [] }, { "cell_type": "code", - "execution_count": 42, - "id": "6dcf5f63", + "execution_count": 8, + "id": "0c123c76-ed90-49fa-903b-70ad60a33f16", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "140011688936544" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-07-29 23:02:56.650 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" + ] } ], "source": [ - "id(f)" + "mapping = instantiate(cfg)" ] }, { "cell_type": "code", - "execution_count": 44, - "id": "74958f8d", + "execution_count": 9, + "id": "ff6c57f0-3c96-418e-8192-cd12bf79c073", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "140011688936544" + "tensor([1002])" ] }, - "execution_count": 44, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "id(l[0])" + "mapping.get_index(\"<p>\")" ] }, { "cell_type": "code", - "execution_count": 45, - "id": "bcceabd5", + "execution_count": 10, + "id": "348391ec-0cf7-49f6-bac2-26bc8c966705", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "140011688936544" + "1006" ] }, - "execution_count": 45, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "id(l[1])" + "len(mapping)" ] }, { "cell_type": "code", - "execution_count": 58, - "id": "191a0b03", + "execution_count": 15, + "id": "67673bf2-79c6-4010-93dd-9c9ba8f9a90e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'nn'" + "tensor([1003])" ] }, - "execution_count": 58, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "\".\".join(\"nn.LayerNorm\".split(\".\")[:-1])" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "id": "4ff8ae08", - "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "'str' object has no attribute 'LayerNorm'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/tmp/ipykernel_31696/162121485.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"torch.nn\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"LayerNorm\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m: 'str' object has no attribute 'LayerNorm'" - ] - } - ], - "source": [ - "getattr(\"torch.nn\", \"LayerNorm\")" + "mapping.get_index(\"\\n\")" ] }, { "cell_type": "code", "execution_count": null, - "id": "4d536bf2", + "id": "8923ea1e-b571-42ee-bfd7-4984aa70644f", "metadata": {}, "outputs": [], "source": [] |