summaryrefslogtreecommitdiff
path: root/notebooks/05c-test-model-end-to-end.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'notebooks/05c-test-model-end-to-end.ipynb')
-rw-r--r--notebooks/05c-test-model-end-to-end.ipynb232
1 files changed, 103 insertions, 129 deletions
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index e2ccb3c..a96e484 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -48,71 +48,7 @@
{
"cell_type": "code",
"execution_count": 4,
- "id": "9c797159-845e-42c6-bd65-1c976ad627cd",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "encoder:\n",
- " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n",
- " arch: b0\n",
- " out_channels: 1280\n",
- " stochastic_dropout_rate: 0.2\n",
- " bn_momentum: 0.99\n",
- " bn_eps: 0.001\n",
- "decoder:\n",
- " _target_: text_recognizer.networks.transformer.Decoder\n",
- " dim: 256\n",
- " depth: 2\n",
- " num_heads: 8\n",
- " attn_fn: text_recognizer.networks.transformer.attention.Attention\n",
- " attn_kwargs:\n",
- " num_heads: 8\n",
- " dim_head: 64\n",
- " dropout_rate: 0.2\n",
- " norm_fn: torch.nn.LayerNorm\n",
- " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n",
- " ff_kwargs:\n",
- " dim: 256\n",
- " dim_out: null\n",
- " expansion_factor: 4\n",
- " glu: true\n",
- " dropout_rate: 0.2\n",
- " rotary_emb: null\n",
- " rotary_emb_dim: null\n",
- " cross_attend: true\n",
- " pre_norm: true\n",
- "_target_: text_recognizer.networks.conv_transformer.ConvTransformer\n",
- "input_dims:\n",
- "- 1\n",
- "- 576\n",
- "- 640\n",
- "hidden_dim: 256\n",
- "dropout_rate: 0.2\n",
- "max_output_len: 682\n",
- "num_classes: 1004\n",
- "start_token: <s>\n",
- "end_token: <e>\n",
- "pad_token: <p>\n",
- "\n",
- "{'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 256, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'num_heads': 8, 'dim_head': 64, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim': 256, 'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'rotary_emb': None, 'rotary_emb_dim': None, 'cross_attend': True, 'pre_norm': True}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 576, 640], 'hidden_dim': 256, 'dropout_rate': 0.2, 'max_output_len': 682, 'num_classes': 1004, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n"
- ]
- }
- ],
- "source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"conv_transformer\")\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "cdb895b6-8949-4318-8a40-06fb5ed5e8d6",
+ "id": "8741a844-3b97-47c4-a2a1-5a268d40923c",
"metadata": {},
"outputs": [
{
@@ -140,7 +76,8 @@
}
],
"source": [
- "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n",
+ "# context initialization\n",
+ "with initialize(config_path=\"../training/conf/model/mapping\", job_name=\"test_app\"):\n",
" cfg = compose(config_name=\"word_piece\")\n",
" print(OmegaConf.to_yaml(cfg))\n",
" print(cfg)"
@@ -148,14 +85,32 @@
},
{
"cell_type": "code",
+ "execution_count": 5,
+ "id": "c9271d46-37b1-4d06-a603-46b5ed82f821",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-07-30 23:08:27.495 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:89 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
+ ]
+ }
+ ],
+ "source": [
+ "tt =instantiate(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 6,
- "id": "b6181656-580a-4d96-8495-b6bb510944cc",
+ "id": "bf1b07ac-9de7-4d24-a36b-09847bc6bc6f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}"
+ "WordPieceMapping(extra_symbols={'\\n'}, mapping=['<b>', '<s>', '<e>', '<p>', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '!', '\"', '#', '&', \"'\", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '?', '\\n'], inverse_mapping={'<b>': 0, '<s>': 1, '<e>': 2, '<p>': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, 'A': 14, 'B': 15, 'C': 16, 'D': 17, 'E': 18, 'F': 19, 'G': 20, 'H': 21, 'I': 22, 'J': 23, 'K': 24, 'L': 25, 'M': 26, 'N': 27, 'O': 28, 'P': 29, 'Q': 30, 'R': 31, 'S': 32, 'T': 33, 'U': 34, 'V': 35, 'W': 36, 'X': 37, 'Y': 38, 'Z': 39, 'a': 40, 'b': 41, 'c': 42, 'd': 43, 'e': 44, 'f': 45, 'g': 46, 'h': 47, 'i': 48, 'j': 49, 'k': 50, 'l': 51, 'm': 52, 'n': 53, 'o': 54, 'p': 55, 'q': 56, 'r': 57, 's': 58, 't': 59, 'u': 60, 'v': 61, 'w': 62, 'x': 63, 'y': 64, 'z': 65, ' ': 66, '!': 67, '\"': 68, '#': 69, '&': 70, \"'\": 71, '(': 72, ')': 73, '*': 74, '+': 75, ',': 76, '-': 77, '.': 78, '/': 79, ':': 80, ';': 81, '?': 82, '\\n': 83}, input_size=[28, 28], data_dir=PosixPath('/home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb'), num_features=1000, tokens='iamdb_1kwp_tokens_1000.txt', lexicon='iamdb_1kwp_lex_1000.txt', use_words=False, prepend_wordsep=False, special_tokens={'<p>', '<s>', '<e>'}, wordpiece_processor=<text_recognizer.data.iam_preprocessor.Preprocessor object at 0x7fa4ec7ea610>)"
]
},
"execution_count": 6,
@@ -164,102 +119,121 @@
}
],
"source": [
- "cfg"
+ "tt"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "5cd80d84-3ae5-4bb4-bc00-0dac7b22e134",
+ "id": "2452e8f4-cc5f-4763-9a25-4fa27b7f143e",
"metadata": {},
"outputs": [],
- "source": []
+ "source": [
+ "tt.mapping"
+ ]
},
{
"cell_type": "code",
- "execution_count": 8,
- "id": "0c123c76-ed90-49fa-903b-70ad60a33f16",
+ "execution_count": null,
+ "id": "6b722ca0-9c65-4f90-be4e-b7334ea81237",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2021-07-29 23:02:56.650 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "mapping = instantiate(cfg)"
+ "# context initialization\n",
+ "with initialize(config_path=\"../training/conf/model/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"lit_transformer\")\n",
+ " print(OmegaConf.to_yaml(cfg))\n",
+ " print(cfg)"
]
},
{
"cell_type": "code",
- "execution_count": 9,
- "id": "ff6c57f0-3c96-418e-8192-cd12bf79c073",
+ "execution_count": null,
+ "id": "9c797159-845e-42c6-bd65-1c976ad627cd",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([1002])"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
+ "source": [
+ "# context initialization\n",
+ "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"conv_transformer\")\n",
+ " print(OmegaConf.to_yaml(cfg))\n",
+ " print(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dcfbe2ab-6775-4aa4-acf4-57203a3f5511",
+ "metadata": {},
+ "outputs": [],
"source": [
- "mapping.get_index(\"<p>\")"
+ "from importlib import import_module"
]
},
{
"cell_type": "code",
- "execution_count": 10,
- "id": "348391ec-0cf7-49f6-bac2-26bc8c966705",
+ "execution_count": null,
+ "id": "e3d4c70e-509d-457a-ac81-2bac27cb95d2",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "1006"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "len(mapping)"
+ "x = import_module(\"text_recognizer.networks.transformer.attention\")"
]
},
{
"cell_type": "code",
- "execution_count": 15,
- "id": "67673bf2-79c6-4010-93dd-9c9ba8f9a90e",
+ "execution_count": null,
+ "id": "039d4a7f-f70d-43a1-8b5f-7e766ac01010",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([1003])"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
+ "source": [
+ "y = partial(getattr(x, \"Attention\"), dim=16, num_heads=2, **cfg.decoder.attn_kwargs)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9be1d661-bfac-4826-ab8d-453557713f68",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y().causal"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "54b35e6f-35db-4769-8bc5-ed1764768cf2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "y(causal=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "af2c8cfa-0b45-4681-b671-0f97ace62516",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "net = instantiate(cfg)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f",
+ "metadata": {},
+ "outputs": [],
"source": [
- "mapping.get_index(\"\\n\")"
+ "net"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "8923ea1e-b571-42ee-bfd7-4984aa70644f",
+ "id": "709be6cc-6708-4561-ad45-28f433612a0d",
"metadata": {},
"outputs": [],
"source": []