author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-05 23:05:25 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-07-05 23:05:25 +0200
commit     4d1f2cef39688871d2caafce42a09316381a27ae (patch)
tree       0f4385969e7df6d7d313cd5910bde9a7475ca027
parent     f0481decdad9afb52494e9e95996deef843ef233 (diff)
Refactor with attr, working on cnn+transformer network
-rw-r--r--  notebooks/00-scratch-pad.ipynb  644
-rw-r--r--  notebooks/03-look-at-iam-paragraphs.ipynb  43
-rw-r--r--  text_recognizer/callbacks/__init__.py  1
-rw-r--r--  text_recognizer/callbacks/wandb_callbacks.py  8
-rw-r--r--  text_recognizer/criterions/__init__.py  1
-rw-r--r--  text_recognizer/criterions/label_smoothing_loss.py (renamed from text_recognizer/networks/loss/label_smoothing_loss.py)  0
-rw-r--r--  text_recognizer/data/base_data_module.py  14
-rw-r--r--  text_recognizer/data/base_dataset.py  24
-rw-r--r--  text_recognizer/models/__init__.py  2
-rw-r--r--  text_recognizer/models/base.py  11
-rw-r--r--  text_recognizer/models/transformer.py  30
-rw-r--r--  text_recognizer/models/vqvae.py  6
-rw-r--r--  text_recognizer/networks/cnn_tranformer.py  14
-rw-r--r--  text_recognizer/networks/loss/__init__.py  2
-rw-r--r--  text_recognizer/networks/util.py  7
15 files changed, 690 insertions, 117 deletions
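
The refactor applies one recurring attrs pattern across the data modules, datasets, models, and callbacks below: hand-written __init__ parameters become attr.ib() fields, the framework base class is initialised in __attrs_pre_init__, and validation or derived state moves to __attrs_post_init__. A condensed sketch of that pattern, modelled on the BaseDataset change further down (the class name is illustrative):

    import attr
    from torch.utils.data import Dataset


    @attr.s
    class ExampleDataset(Dataset):
        # Fields replace the old __init__ parameters.
        data = attr.ib()
        targets = attr.ib()

        def __attrs_pre_init__(self) -> None:
            # Runs before attrs assigns the fields: set up the torch base class.
            super().__init__()

        def __attrs_post_init__(self) -> None:
            # Runs after the fields are assigned: validation and derived state live here.
            if len(self.data) != len(self.targets):
                raise ValueError("Data and targets must be of equal length.")
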
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index 2ade2bb..16c6533 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -30,106 +30,244 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
+ "from pathlib import Path"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": false
- },
+ "execution_count": 2,
+ "metadata": {},
"outputs": [],
"source": [
- "en = EfficientNet(\"b0\")"
+ "import attr"
]
},
{
"cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": false
- },
+ "execution_count": 9,
+ "metadata": {},
"outputs": [],
"source": [
- "summary(en, (1, 224, 224));"
+ "@attr.s\n",
+ "class B:\n",
+ " batch_size = attr.ib()\n",
+ " num_workers = attr.ib()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
- "torch.cuda.is_available()"
+ "@attr.s\n",
+ "class T(B):\n",
+ "\n",
+ " def __attrs_post_init__(self) -> None:\n",
+ " super().__init__(self.batch_size, self.num_workers)\n",
+ " self.hej = None\n",
+ " \n",
+ " batch_size = attr.ib()\n",
+ " num_workers = attr.ib()\n",
+ " h: Path = attr.ib(converter=Path)"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
- "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
+ "t = T(batch_size=16, num_workers=2, h=\"hej\")"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 13,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PosixPath('hej')"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "decoder.cuda()"
+ "t.h"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 12,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "16"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "transformer_decoder = Transformer(num_tokens=1000, max_seq_len=690, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"
+ "t.batch_size"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 11,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "PosixPath('hej')"
+ ]
+ },
+ "execution_count": 11,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
"source": [
- "transformer_decoder.cuda()"
+ "t.h"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "../text_recognizer/__init__.py\n",
+ "../text_recognizer/callbacks/__init__.py\n",
+ "../text_recognizer/callbacks/wandb_callbacks.py\n",
+ "../text_recognizer/data/image_utils.py\n",
+ "../text_recognizer/data/emnist.py\n",
+ "../text_recognizer/data/iam_lines.py\n",
+ "../text_recognizer/data/download_utils.py\n",
+ "../text_recognizer/data/mappings.py\n",
+ "../text_recognizer/data/iam_preprocessor.py\n",
+ "../text_recognizer/data/__init__.py\n",
+ "../text_recognizer/data/make_wordpieces.py\n",
+ "../text_recognizer/data/iam_paragraphs.py\n",
+ "../text_recognizer/data/sentence_generator.py\n",
+ "../text_recognizer/data/emnist_lines.py\n",
+ "../text_recognizer/data/build_transitions.py\n",
+ "../text_recognizer/data/base_dataset.py\n",
+ "../text_recognizer/data/base_data_module.py\n",
+ "../text_recognizer/data/iam.py\n",
+ "../text_recognizer/data/iam_synthetic_paragraphs.py\n",
+ "../text_recognizer/data/transforms.py\n",
+ "../text_recognizer/data/iam_extended_paragraphs.py\n",
+ "../text_recognizer/networks/__init__.py\n",
+ "../text_recognizer/networks/util.py\n",
+ "../text_recognizer/networks/cnn_tranformer.py\n",
+ "../text_recognizer/networks/encoders/__init__.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/efficientnet.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/__init__.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/utils.py\n",
+ "../text_recognizer/networks/encoders/efficientnet/mbconv.py\n",
+ "../text_recognizer/networks/loss/__init__.py\n",
+ "../text_recognizer/networks/loss/label_smoothing_loss.py\n",
+ "../text_recognizer/networks/vqvae/__init__.py\n",
+ "../text_recognizer/networks/vqvae/decoder.py\n",
+ "../text_recognizer/networks/vqvae/vqvae.py\n",
+ "../text_recognizer/networks/vqvae/vector_quantizer.py\n",
+ "../text_recognizer/networks/vqvae/encoder.py\n",
+ "../text_recognizer/networks/transformer/__init__.py\n",
+ "../text_recognizer/networks/transformer/layers.py\n",
+ "../text_recognizer/networks/transformer/residual.py\n",
+ "../text_recognizer/networks/transformer/attention.py\n",
+ "../text_recognizer/networks/transformer/transformer.py\n",
+ "../text_recognizer/networks/transformer/vit.py\n",
+ "../text_recognizer/networks/transformer/mlp.py\n",
+ "../text_recognizer/networks/transformer/norm.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/positional_encoding.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/__init__.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/absolute_embedding.py\n",
+ "../text_recognizer/networks/transformer/positional_encodings/rotary_embedding.py\n",
+ "../text_recognizer/networks/transformer/nystromer/__init__.py\n",
+ "../text_recognizer/networks/transformer/nystromer/nystromer.py\n",
+ "../text_recognizer/networks/transformer/nystromer/attention.py\n",
+ "../text_recognizer/models/__init__.py\n",
+ "../text_recognizer/models/base.py\n",
+ "../text_recognizer/models/vqvae.py\n",
+ "../text_recognizer/models/transformer.py\n",
+ "../text_recognizer/models/dino.py\n",
+ "../text_recognizer/models/metrics.py\n"
+ ]
+ }
+ ],
+ "source": [
+ "for f in Path(\"../text_recognizer\").glob(\"**/*.py\"):\n",
+ " print(f)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<generator object Path.glob at 0x7ff8bb9ce5f0>"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "Path(\"..\").glob(\"**/*.py\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "efficient_transformer = Nystromer(\n",
- " dim = 64,\n",
- " depth = 4,\n",
- " num_heads = 8,\n",
- " num_landmarks = 64\n",
- ")"
+ "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {
+ "scrolled": false
+ },
+ "outputs": [],
+ "source": [
+ "en = EfficientNet(\"b0\")"
]
},
{
"cell_type": "code",
"execution_count": null,
- "metadata": {},
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [],
"source": [
- "v = ViT(\n",
- " dim = 64,\n",
- " image_size = (576, 640),\n",
- " patch_size = (32, 32),\n",
- " transformer = efficient_transformer\n",
- ").cuda()"
+ "summary(en, (1, 224, 224));"
]
},
{
@@ -138,7 +276,7 @@
"metadata": {},
"outputs": [],
"source": [
- "t = torch.randn(8, 1, 576, 640).cuda()"
+ "torch.cuda.is_available()"
]
},
{
@@ -147,7 +285,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en.cuda()"
+ "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
]
},
{
@@ -156,7 +294,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en(t)"
+ "decoder.cuda()"
]
},
{
@@ -165,7 +303,7 @@
"metadata": {},
"outputs": [],
"source": [
- "o = v(t)"
+ "transformer_decoder = Transformer(num_tokens=1003, max_seq_len=451, attn_layers=decoder, emb_dim=128, emb_dropout=0.1)"
]
},
{
@@ -174,7 +312,7 @@
"metadata": {},
"outputs": [],
"source": [
- "caption = torch.randint(0, 90, (16, 690)).cuda()"
+ "transformer_decoder.cuda()"
]
},
{
@@ -183,7 +321,12 @@
"metadata": {},
"outputs": [],
"source": [
- "o.shape"
+ "efficient_transformer = Nystromer(\n",
+ " dim = 64,\n",
+ " depth = 4,\n",
+ " num_heads = 8,\n",
+ " num_landmarks = 64\n",
+ ")"
]
},
{
@@ -192,16 +335,405 @@
"metadata": {},
"outputs": [],
"source": [
- "caption.shape"
+ "v = ViT(\n",
+ " dim = 64,\n",
+ " image_size = (576, 640),\n",
+ " patch_size = (32, 32),\n",
+ " transformer = efficient_transformer\n",
+ ").cuda()"
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
- "o = torch.randn(16, 20 * 18, 128).cuda()"
+ "t = torch.randn(4, 1, 576, 640).cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "EfficientNet(\n",
+ " (_conv_stem): Sequential(\n",
+ " (0): ZeroPad2d(padding=(0, 1, 0, 1), value=0.0)\n",
+ " (1): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)\n",
+ " (2): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (3): Mish(inplace=True)\n",
+ " )\n",
+ " (_blocks): ModuleList(\n",
+ " (0): MBConvBlock(\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)\n",
+ " (1): BatchNorm2d(32, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(16, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (1): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(16, 96, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(96, 96, kernel_size=(3, 3), stride=(2, 2), groups=96, bias=False)\n",
+ " (1): BatchNorm2d(96, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(24, 96, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(96, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (2): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), groups=144, bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(144, 24, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(24, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (3): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(144, 144, kernel_size=(5, 5), stride=(2, 2), groups=144, bias=False)\n",
+ " (1): BatchNorm2d(144, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(144, 36, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(36, 144, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(144, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (4): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(240, 240, kernel_size=(5, 5), stride=(1, 1), groups=240, bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(240, 40, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(40, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (5): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(40, 240, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(240, 240, kernel_size=(3, 3), stride=(2, 2), groups=240, bias=False)\n",
+ " (1): BatchNorm2d(240, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(240, 60, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(60, 240, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(240, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (6): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (7): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(480, 480, kernel_size=(3, 3), stride=(1, 1), groups=480, bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(480, 80, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(80, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (8): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(80, 480, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(480, 480, kernel_size=(5, 5), stride=(1, 1), groups=480, bias=False)\n",
+ " (1): BatchNorm2d(480, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(480, 120, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(120, 480, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(480, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (9): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (10): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(1, 1), groups=672, bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(672, 112, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(112, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (11): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(112, 672, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(672, 672, kernel_size=(5, 5), stride=(2, 2), groups=672, bias=False)\n",
+ " (1): BatchNorm2d(672, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(672, 168, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(168, 672, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(672, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (12): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (13): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (14): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(5, 5), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 192, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(192, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " (15): MBConvBlock(\n",
+ " (_inverted_bottleneck): Sequential(\n",
+ " (0): Conv2d(192, 1152, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_depthwise): Sequential(\n",
+ " (0): Conv2d(1152, 1152, kernel_size=(3, 3), stride=(1, 1), groups=1152, bias=False)\n",
+ " (1): BatchNorm2d(1152, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " (2): Mish(inplace=True)\n",
+ " )\n",
+ " (_squeeze_excite): Sequential(\n",
+ " (0): Conv2d(1152, 288, kernel_size=(1, 1), stride=(1, 1))\n",
+ " (1): Mish(inplace=True)\n",
+ " (2): Conv2d(288, 1152, kernel_size=(1, 1), stride=(1, 1))\n",
+ " )\n",
+ " (_pointwise): Sequential(\n",
+ " (0): Conv2d(1152, 320, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(320, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (_conv_head): Sequential(\n",
+ " (0): Conv2d(320, 1280, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
+ " (1): BatchNorm2d(1280, eps=0.001, momentum=0.99, affine=True, track_running_stats=True)\n",
+ " )\n",
+ ")"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "en.cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([4, 1280, 18, 20])"
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "en(t).shape"
]
},
{
@@ -210,7 +742,7 @@
"metadata": {},
"outputs": [],
"source": [
- "caption = torch.randint(0, 1000, (16, 200)).cuda()"
+ "o = v(t)"
]
},
{
@@ -219,7 +751,7 @@
"metadata": {},
"outputs": [],
"source": [
- "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)"
+ "caption = torch.randint(0, 90, (16, 690)).cuda()"
]
},
{
@@ -228,7 +760,7 @@
"metadata": {},
"outputs": [],
"source": [
- "from text_recognizer.networks.encoders.efficientnet.efficientnet import EfficientNet"
+ "o.shape"
]
},
{
@@ -237,7 +769,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en = EfficientNet()"
+ "caption.shape"
]
},
{
@@ -246,7 +778,7 @@
"metadata": {},
"outputs": [],
"source": [
- "en.cuda()"
+ "o = torch.randn(16, 20 * 18, 128).cuda()"
]
},
{
@@ -255,7 +787,7 @@
"metadata": {},
"outputs": [],
"source": [
- "summary(en, (1, 576, 640))"
+ "caption = torch.randint(0, 1000, (16, 200)).cuda()"
]
},
{
@@ -264,7 +796,7 @@
"metadata": {},
"outputs": [],
"source": [
- "type(efficient_transformer)"
+ "transformer_decoder(caption, context = o).shape # (1, 1024, 20000)"
]
},
{
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
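The scratch-pad cells above probe the two halves of the cnn+transformer network separately: EfficientNet("b0") maps a (4, 1, 576, 640) batch to (4, 1280, 18, 20) features, and the transformer decoder consumes a (batch, 18 * 20, 128) context together with a caption tensor. A rough sketch of how the pieces presumably connect; the 1x1 projection, the flatten/permute, and the class name are assumptions, not code from the repo:

    from torch import nn, Tensor


    class CnnTransformerSketch(nn.Module):
        """Hypothetical wiring of the encoder and decoder probed in the notebook."""

        def __init__(self, encoder: nn.Module, decoder: nn.Module, hidden_dim: int = 128) -> None:
            super().__init__()
            self.encoder = encoder  # e.g. EfficientNet("b0"): (B, 1, 576, 640) -> (B, 1280, 18, 20)
            self.project = nn.Conv2d(1280, hidden_dim, kernel_size=1)  # assumed projection to the decoder width
            self.decoder = decoder  # e.g. Transformer(num_tokens=1003, max_seq_len=451, ...)

        def forward(self, image: Tensor, tokens: Tensor) -> Tensor:
            features = self.project(self.encoder(image))    # (B, 128, 18, 20)
            context = features.flatten(2).permute(0, 2, 1)  # (B, 18 * 20, 128), the context shape used above
            return self.decoder(tokens, context=context)    # logits over the token vocabulary
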
index 37fef04..315b7bf 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -40,7 +40,7 @@
},
{
"cell_type": "code",
- "execution_count": 54,
+ "execution_count": 2,
"id": "726ac25b",
"metadata": {},
"outputs": [],
@@ -57,7 +57,7 @@
},
{
"cell_type": "code",
- "execution_count": 64,
+ "execution_count": 3,
"id": "c6188bce",
"metadata": {
"scrolled": true
@@ -67,13 +67,13 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2021-06-27 20:17:40.498 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-06-27 20:17:40.682 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-06-27 20:17:40.777 | INFO | text_recognizer.data.iam_paragraphs:setup:111 - Loading IAM paragraph regions and lines for None...\n",
- "2021-06-27 20:17:54.542 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-06-27 20:17:56.911 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
- "2021-06-27 20:17:57.147 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:75 - IAM Synthetic dataset steup for stage None...\n",
- "2021-06-27 20:18:07.707 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
+ "2021-06-27 20:59:27.366 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
+ "2021-06-27 20:59:27.464 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
+ "2021-06-27 20:59:27.559 | INFO | text_recognizer.data.iam_paragraphs:setup:111 - Loading IAM paragraph regions and lines for None...\n",
+ "2021-06-27 20:59:40.932 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
+ "2021-06-27 20:59:43.173 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n",
+ "2021-06-27 20:59:43.267 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:75 - IAM Synthetic dataset steup for stage None...\n",
+ "2021-06-27 20:59:53.470 | DEBUG | text_recognizer.data.mappings:_configure_wordpiece_processor:104 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
]
},
{
@@ -84,8 +84,8 @@
"Num classes: 1006\n",
"Dims: (1, 576, 640)\n",
"Output dims: (682, 1)\n",
- "Train/val/test sizes: 19907, 262, 231\n",
- "Train Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0105), tensor(0.0575), tensor(1.))\n",
+ "Train/val/test sizes: 19957, 262, 231\n",
+ "Train Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0111), tensor(0.0604), tensor(1.))\n",
"Train Batch y stats: (torch.Size([8, 451]), torch.int64, tensor(1), tensor(1004))\n",
"Test Batch x stats: (torch.Size([8, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0315), tensor(0.0799), tensor(0.9098))\n",
"Test Batch y stats: (torch.Size([8, 451]), torch.int64, tensor(1), tensor(1003))\n",
@@ -102,6 +102,27 @@
},
{
"cell_type": "code",
+ "execution_count": 4,
+ "id": "55b26b5d",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1006"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(dataset.mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": null,
"id": "42501428",
"metadata": {},
diff --git a/text_recognizer/callbacks/__init__.py b/text_recognizer/callbacks/__init__.py
index e69de29..82d8ce3 100644
--- a/text_recognizer/callbacks/__init__.py
+++ b/text_recognizer/callbacks/__init__.py
@@ -0,0 +1 @@
+"""Module for PyTorch Lightning callbacks."""
diff --git a/text_recognizer/callbacks/wandb_callbacks.py b/text_recognizer/callbacks/wandb_callbacks.py
index 900c3b1..4186b4a 100644
--- a/text_recognizer/callbacks/wandb_callbacks.py
+++ b/text_recognizer/callbacks/wandb_callbacks.py
@@ -29,7 +29,7 @@ class WatchModel(Callback):
log: str = attr.ib(default="gradients")
log_freq: int = attr.ib(default=100)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -44,7 +44,7 @@ class UploadCodeAsArtifact(Callback):
project_dir: Path = attr.ib(converter=Path)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -65,7 +65,7 @@ class UploadCheckpointAsArtifact(Callback):
ckpt_dir: Path = attr.ib(converter=Path)
upload_best_only: bool = attr.ib()
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_train_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
@@ -90,7 +90,7 @@ class LogTextPredictions(Callback):
num_samples: int = attr.ib(default=8)
ready: bool = attr.ib(default=True)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
def on_sanity_check_start(
diff --git a/text_recognizer/criterions/__init__.py b/text_recognizer/criterions/__init__.py
new file mode 100644
index 0000000..5b0a7ab
--- /dev/null
+++ b/text_recognizer/criterions/__init__.py
@@ -0,0 +1 @@
+"""Module with custom loss functions."""
diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing_loss.py
index 40a7609..40a7609 100644
--- a/text_recognizer/networks/loss/label_smoothing_loss.py
+++ b/text_recognizer/criterions/label_smoothing_loss.py
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 8b5c188..de5628f 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -2,7 +2,8 @@
from pathlib import Path
from typing import Dict
-import pytorch_lightning as pl
+import attr
+import pytorch_lightning as LightningDataModule
from torch.utils.data import DataLoader
@@ -14,14 +15,17 @@ def load_and_print_info(data_module_class: type) -> None:
print(dataset)
-class BaseDataModule(pl.LightningDataModule):
+@attr.s
+class BaseDataModule(LightningDataModule):
"""Base PyTorch Lightning DataModule."""
- def __init__(self, batch_size: int = 128, num_workers: int = 0) -> None:
+ batch_size: int = attr.ib(default=16)
+ num_workers: int = attr.ib(default=0)
+
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.batch_size = batch_size
- self.num_workers = num_workers
+ def __attrs_post_init__(self) -> None:
# Placeholders for subclasses.
self.dims = None
self.output_dims = None
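
One thing to flag in the hunk above: "import pytorch_lightning as LightningDataModule" binds the whole pytorch_lightning module under that name, so the class statement ends up inheriting from a module object rather than from the Lightning class. The intended import is presumably the class itself, with the attrs-based body otherwise as shown:

    from pytorch_lightning import LightningDataModule
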
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index 8d644d4..4318dfb 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,11 +1,13 @@
"""Base PyTorch Dataset class."""
from typing import Any, Callable, Dict, Sequence, Tuple, Union
+import attr
import torch
from torch import Tensor
from torch.utils.data import Dataset
+@attr.s
class BaseDataset(Dataset):
"""
Base Dataset class that processes data and targets through optional transfroms.
@@ -18,19 +20,17 @@ class BaseDataset(Dataset):
target transforms.
"""
- def __init__(
- self,
- data: Union[Sequence, Tensor],
- targets: Union[Sequence, Tensor],
- transform: Callable = None,
- target_transform: Callable = None,
- ) -> None:
- if len(data) != len(targets):
+ data: Union[Sequence, Tensor] = attr.ib()
+ targets: Union[Sequence, Tensor] = attr.ib()
+ transform: Callable = attr.ib()
+ target_transform: Callable = attr.ib()
+
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ def __attrs_post_init__(self) -> None:
+ if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
- self.data = data
- self.targets = targets
- self.transform = transform
- self.target_transform = target_transform
def __len__(self) -> int:
"""Return the length of the dataset."""
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index 5ac2510..1982daf 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -1,3 +1 @@
"""PyTorch Lightning models modules."""
-from .transformer import LitTransformerModel
-from .vqvae import LitVQVAEModel
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 4e803eb..8dc7a36 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -1,5 +1,5 @@
"""Base PyTorch Lightning model."""
-from typing import Any, Dict, List, Union, Tuple, Type
+from typing import Any, Dict, List, Tuple, Type
import attr
import hydra
@@ -13,7 +13,7 @@ import torchmetrics
@attr.s
-class LitBaseModel(pl.LightningModule):
+class BaseLitModel(pl.LightningModule):
"""Abstract PyTorch Lightning class."""
network: Type[nn.Module] = attr.ib()
@@ -30,18 +30,17 @@ class LitBaseModel(pl.LightningModule):
val_acc = attr.ib(init=False)
test_acc = attr.ib(init=False)
- def __attrs_pre_init__(self):
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- def __attrs_post_init__(self):
- self.loss_fn = self.configure_criterion()
+ def __attrs_post_init__(self) -> None:
+ self.loss_fn = self._configure_criterion()
# Accuracy metric
self.train_acc = torchmetrics.Accuracy()
self.val_acc = torchmetrics.Accuracy()
self.test_acc = torchmetrics.Accuracy()
- @staticmethod
def configure_criterion(self) -> Type[nn.Module]:
"""Returns a loss functions."""
log.info(f"Instantiating criterion <{self.criterion_config._target_}>")
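
Worth noting in the hunk above: __attrs_post_init__ now calls self._configure_criterion(), while the method it refers to is still named configure_criterion (only the @staticmethod decorator is dropped), so one of the two names presumably still has to change for the call to resolve.
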
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6be0ac5..ea54d83 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,27 +1,35 @@
"""PyTorch Lightning model for base Transformers."""
from typing import Dict, List, Optional, Union, Tuple, Type
+import attr
from omegaconf import DictConfig
from torch import nn, Tensor
from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import LitBaseModel
-class LitTransformerModel(LitBaseModel):
+@attr.s
+class TransformerLitModel(LitBaseModel):
"""A PyTorch Lightning model for transformer networks."""
- def __init__(
- self,
- network: Type[nn.Module],
- optimizer: Union[DictConfig, Dict],
- lr_scheduler: Union[DictConfig, Dict],
- criterion: Union[DictConfig, Dict],
- monitor: str = "val_loss",
- mapping: Optional[List[str]] = None,
- ) -> None:
- super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+ network: Type[nn.Module] = attr.ib()
+ criterion_config: DictConfig = attr.ib(converter=DictConfig)
+ optimizer_config: DictConfig = attr.ib(converter=DictConfig)
+ lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
+ monitor: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib()
+
+ def __attrs_post_init__(self) -> None:
+ super().__init__(
+ network=self.network,
+ optimizer_config=self.optimizer_config,
+ lr_scheduler_config=self.lr_scheduler_config,
+ criterion_config=self.criterion_config,
+ monitor=self.monitor,
+ )
self.mapping, ignore_tokens = self.configure_mapping(mapping)
self.val_cer = CharacterErrorRate(ignore_tokens)
self.test_cer = CharacterErrorRate(ignore_tokens)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 18e8691..7dc950f 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -18,7 +18,7 @@ class LitVQVAEModel(LitBaseModel):
optimizer: Union[DictConfig, Dict],
lr_scheduler: Union[DictConfig, Dict],
criterion: Union[DictConfig, Dict],
- monitor: str = "val_loss",
+ monitor: str = "val/loss",
*args: Any,
**kwargs: Dict,
) -> None:
@@ -50,7 +50,7 @@ class LitVQVAEModel(LitBaseModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- self.log("train_loss", loss)
+ self.log("train/loss", loss)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -59,7 +59,7 @@ class LitVQVAEModel(LitBaseModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- self.log("val_loss", loss, prog_bar=True)
+ self.log("val/loss", loss, prog_bar=True)
title = "val_pred_examples"
self._log_prediction(data, reconstructions, title)
diff --git a/text_recognizer/networks/cnn_tranformer.py b/text_recognizer/networks/cnn_tranformer.py
new file mode 100644
index 0000000..da69311
--- /dev/null
+++ b/text_recognizer/networks/cnn_tranformer.py
@@ -0,0 +1,14 @@
+"""Vision transformer for character recognition."""
+from typing import Type
+
+import attr
+from torch import nn, Tensor
+
+
+@attr.s
+class CnnTransformer(nn.Module):
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
+
+ backbone: Type[nn.Module] = attr.ib()
+ head = Type[nn.Module] = attr.ib()
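
In the new module above, the last line chains two "=" signs ("head = Type[nn.Module] = attr.ib()"), which is an assignment rather than the intended annotation. The declaration presumably meant here, keeping everything else from the hunk, is:

    from typing import Type

    import attr
    from torch import nn


    @attr.s
    class CnnTransformer(nn.Module):
        """Vision transformer for character recognition."""

        def __attrs_pre_init__(self) -> None:
            super().__init__()

        backbone: Type[nn.Module] = attr.ib()
        head: Type[nn.Module] = attr.ib()
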
diff --git a/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py
deleted file mode 100644
index cb83608..0000000
--- a/text_recognizer/networks/loss/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""Loss module."""
-from .loss import LabelSmoothingCrossEntropy
diff --git a/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 05b10a8..109bf4d 100644
--- a/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
@@ -1,10 +1,6 @@
"""Miscellaneous neural network functionality."""
-import importlib
-from pathlib import Path
-from typing import Dict, NamedTuple, Union, Type
+from typing import Type
-from loguru import logger
-import torch
from torch import nn
@@ -19,6 +15,7 @@ def activation_function(activation: str) -> Type[nn.Module]:
["none", nn.Identity()],
["relu", nn.ReLU(inplace=True)],
["selu", nn.SELU(inplace=True)],
+ ["mish", nn.Mish(inplace=True)],
]
)
return activation_fns[activation.lower()]
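
A quick usage sketch for the new "mish" entry, assuming activation_function keeps the dict-lookup behaviour shown in the hunk:

    from torch import nn

    from text_recognizer.networks.util import activation_function

    act = activation_function("Mish")  # key is lower-cased internally, returns nn.Mish(inplace=True)
    assert isinstance(act, nn.Mish)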