author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-08-02 21:13:48 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-08-02 21:13:48 +0200
commit    75801019981492eedf9280cb352eea3d8e99b65f (patch)
tree      6521cc4134459e42591b2375f70acd348741474e
parent    e5eca28438cd17d436359f2c6eee0bb9e55d2a8b (diff)
Fix log import, fix mapping in datamodules, and make nn modules hashable
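The root cause of the broken import (also visible in the notebook traceback below): `loguru.logger` is an attribute of the `loguru` package, not a submodule, so `import loguru.logger` raises ModuleNotFoundError. A minimal sketch of the failing and working forms (the broken line is commented out so the snippet runs):

    # import loguru.logger as log   # ModuleNotFoundError: No module named 'loguru.logger'
    from loguru import logger as log  # importing the attribute works

    log.info("logger imported")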
-rw-r--r-- notebooks/03-look-at-iam-paragraphs.ipynb | 18
-rw-r--r-- notebooks/05c-test-model-end-to-end.ipynb | 152
-rw-r--r-- text_recognizer/data/base_data_module.py | 14
-rw-r--r-- text_recognizer/data/base_dataset.py | 11
-rw-r--r-- text_recognizer/data/download_utils.py | 8
-rw-r--r-- text_recognizer/data/emnist.py | 21
-rw-r--r-- text_recognizer/data/emnist_lines.py | 21
-rw-r--r-- text_recognizer/data/iam.py | 4
-rw-r--r-- text_recognizer/data/iam_extended_paragraphs.py | 6
-rw-r--r-- text_recognizer/data/iam_lines.py | 21
-rw-r--r-- text_recognizer/data/iam_paragraphs.py | 18
-rw-r--r-- text_recognizer/data/iam_preprocessor.py | 16
-rw-r--r-- text_recognizer/data/iam_synthetic_paragraphs.py | 19
-rw-r--r-- text_recognizer/data/make_wordpieces.py | 8
-rw-r--r-- text_recognizer/data/mappings.py | 24
-rw-r--r-- text_recognizer/models/base.py | 2
-rw-r--r-- text_recognizer/models/metrics.py | 4
-rw-r--r-- text_recognizer/models/transformer.py | 16
-rw-r--r-- text_recognizer/models/vqvae.py | 2
-rw-r--r-- text_recognizer/networks/conv_transformer.py | 3
-rw-r--r-- text_recognizer/networks/encoders/efficientnet/efficientnet.py | 15
-rw-r--r-- text_recognizer/networks/encoders/efficientnet/mbconv.py | 139
-rw-r--r-- text_recognizer/networks/transformer/attention.py | 7
-rw-r--r-- text_recognizer/networks/transformer/layers.py | 6
-rw-r--r-- training/conf/callbacks/wandb.yaml | 20
-rw-r--r-- training/conf/callbacks/wandb/checkpoints.yaml | 4
-rw-r--r-- training/conf/callbacks/wandb/code.yaml | 3
-rw-r--r-- training/conf/callbacks/wandb/image_reconstructions.yaml | 0
-rw-r--r-- training/conf/callbacks/wandb/ocr_predictions.yaml | 3
-rw-r--r-- training/conf/callbacks/wandb/watch.yaml | 4
-rw-r--r-- training/conf/callbacks/wandb_ocr.yaml | 6
-rw-r--r-- training/conf/config.yaml | 18
-rw-r--r-- training/conf/criterion/label_smoothing.yaml | 2
-rw-r--r-- training/conf/hydra/default.yaml | 6
-rw-r--r-- training/conf/mapping/word_piece.yaml (renamed from training/conf/model/mapping/word_piece.yaml) | 4
-rw-r--r-- training/conf/model/lit_transformer.yaml | 5
-rw-r--r-- training/conf/network/conv_transformer.yaml | 2
-rw-r--r-- training/conf/network/decoder/transformer_decoder.yaml | 4
-rw-r--r-- training/conf/trainer/default.yaml | 6
-rw-r--r-- training/run.py | 11
-rw-r--r-- training/utils.py | 2
41 files changed, 307 insertions(+), 348 deletions(-)
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index 5e3a872..76ca6b1 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -5,7 +5,21 @@
"execution_count": 1,
"id": "6ce2519f",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "ModuleNotFoundError",
+ "evalue": "No module named 'loguru.logger'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/tmp/ipykernel_3883/2979229631.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_synthetic_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMSyntheticParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_extended_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMExtendedParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/data/iam_paragraphs.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0memnist\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0memnist_mapping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAM\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmappings\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWordPieceMapping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWordPiece\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/data/mappings.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mattr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mloguru\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogger\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mlog\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'loguru.logger'"
+ ]
+ }
+ ],
"source": [
"import os\n",
"os.environ['CUDA_VISIBLE_DEVICE'] = ''\n",
@@ -31,7 +45,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "726ac25b",
"metadata": {},
"outputs": [],
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index a96e484..b652bdd 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -26,16 +26,6 @@
{
"cell_type": "code",
"execution_count": 2,
- "id": "3e812a1e",
- "metadata": {},
- "outputs": [],
- "source": [
- "import attr"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
"id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0",
"metadata": {},
"outputs": [],
@@ -47,193 +37,163 @@
},
{
"cell_type": "code",
- "execution_count": 4,
- "id": "8741a844-3b97-47c4-a2a1-5a268d40923c",
+ "execution_count": 3,
+ "id": "6b722ca0-9c65-4f90-be4e-b7334ea81237",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "_target_: text_recognizer.data.mappings.WordPieceMapping\n",
- "num_features: 1000\n",
- "tokens: iamdb_1kwp_tokens_1000.txt\n",
- "lexicon: iamdb_1kwp_lex_1000.txt\n",
- "data_dir: null\n",
- "use_words: false\n",
- "prepend_wordsep: false\n",
- "special_tokens:\n",
+ "mapping:\n",
+ " _target_: text_recognizer.data.mappings.WordPieceMapping\n",
+ " num_features: 1000\n",
+ " tokens: iamdb_1kwp_tokens_1000.txt\n",
+ " lexicon: iamdb_1kwp_lex_1000.txt\n",
+ " data_dir: null\n",
+ " use_words: false\n",
+ " prepend_wordsep: false\n",
+ " special_tokens:\n",
+ " - <s>\n",
+ " - <e>\n",
+ " - <p>\n",
+ " extra_symbols:\n",
+ " - \\n\n",
+ "_target_: text_recognizer.models.transformer.TransformerLitModel\n",
+ "interval: step\n",
+ "monitor: val/loss\n",
+ "ignore_tokens:\n",
"- <s>\n",
"- <e>\n",
"- <p>\n",
- "extra_symbols:\n",
- "- '\n",
+ "start_token: <s>\n",
+ "end_token: <e>\n",
+ "pad_token: <p>\n",
"\n",
- " '\n",
- "\n",
- "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}\n"
+ "{'mapping': {'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\\\n']}, '_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'ignore_tokens': ['<s>', '<e>', '<p>'], 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n"
]
}
],
"source": [
"# context initialization\n",
- "with initialize(config_path=\"../training/conf/model/mapping\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"word_piece\")\n",
+ "with initialize(config_path=\"../training/conf/model/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"lit_transformer\")\n",
" print(OmegaConf.to_yaml(cfg))\n",
" print(cfg)"
]
},
{
"cell_type": "code",
- "execution_count": 5,
- "id": "c9271d46-37b1-4d06-a603-46b5ed82f821",
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2021-07-30 23:08:27.495 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:89 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
- ]
- }
- ],
- "source": [
- "tt =instantiate(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "id": "bf1b07ac-9de7-4d24-a36b-09847bc6bc6f",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "WordPieceMapping(extra_symbols={'\\n'}, mapping=['<b>', '<s>', '<e>', '<p>', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '!', '\"', '#', '&', \"'\", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '?', '\\n'], inverse_mapping={'<b>': 0, '<s>': 1, '<e>': 2, '<p>': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, 'A': 14, 'B': 15, 'C': 16, 'D': 17, 'E': 18, 'F': 19, 'G': 20, 'H': 21, 'I': 22, 'J': 23, 'K': 24, 'L': 25, 'M': 26, 'N': 27, 'O': 28, 'P': 29, 'Q': 30, 'R': 31, 'S': 32, 'T': 33, 'U': 34, 'V': 35, 'W': 36, 'X': 37, 'Y': 38, 'Z': 39, 'a': 40, 'b': 41, 'c': 42, 'd': 43, 'e': 44, 'f': 45, 'g': 46, 'h': 47, 'i': 48, 'j': 49, 'k': 50, 'l': 51, 'm': 52, 'n': 53, 'o': 54, 'p': 55, 'q': 56, 'r': 57, 's': 58, 't': 59, 'u': 60, 'v': 61, 'w': 62, 'x': 63, 'y': 64, 'z': 65, ' ': 66, '!': 67, '\"': 68, '#': 69, '&': 70, \"'\": 71, '(': 72, ')': 73, '*': 74, '+': 75, ',': 76, '-': 77, '.': 78, '/': 79, ':': 80, ';': 81, '?': 82, '\\n': 83}, input_size=[28, 28], data_dir=PosixPath('/home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb'), num_features=1000, tokens='iamdb_1kwp_tokens_1000.txt', lexicon='iamdb_1kwp_lex_1000.txt', use_words=False, prepend_wordsep=False, special_tokens={'<p>', '<s>', '<e>'}, wordpiece_processor=<text_recognizer.data.iam_preprocessor.Preprocessor object at 0x7fa4ec7ea610>)"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "tt"
- ]
- },
- {
- "cell_type": "code",
"execution_count": null,
- "id": "2452e8f4-cc5f-4763-9a25-4fa27b7f143e",
+ "id": "9c797159-845e-42c6-bd65-1c976ad627cd",
"metadata": {},
"outputs": [],
"source": [
- "tt.mapping"
+ "# context initialization\n",
+ "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
+ " cfg = compose(config_name=\"conv_transformer\")\n",
+ " print(OmegaConf.to_yaml(cfg))\n",
+ " print(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "6b722ca0-9c65-4f90-be4e-b7334ea81237",
+ "id": "af2c8cfa-0b45-4681-b671-0f97ace62516",
"metadata": {},
"outputs": [],
"source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/model/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"lit_transformer\")\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
+ "net = instantiate(cfg)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "9c797159-845e-42c6-bd65-1c976ad627cd",
- "metadata": {},
+ "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f",
+ "metadata": {
+ "tags": []
+ },
"outputs": [],
"source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"conv_transformer\")\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
+ "net"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "dcfbe2ab-6775-4aa4-acf4-57203a3f5511",
+ "id": "40be59bc-db79-4af1-9df4-e280f7a56481",
"metadata": {},
"outputs": [],
"source": [
- "from importlib import import_module"
+ "img = torch.rand(4, 1, 576, 640)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "e3d4c70e-509d-457a-ac81-2bac27cb95d2",
+ "id": "d5a8f10b-edf5-4a18-9747-f016db72c384",
"metadata": {},
"outputs": [],
"source": [
- "x = import_module(\"text_recognizer.networks.transformer.attention\")"
+ "y = torch.randint(0, 1006, (4, 451))"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "039d4a7f-f70d-43a1-8b5f-7e766ac01010",
+ "id": "19423ef1-3d98-4af3-8748-fdd3bb817300",
"metadata": {},
"outputs": [],
"source": [
- "y = partial(getattr(x, \"Attention\"), dim=16, num_heads=2, **cfg.decoder.attn_kwargs)"
+ "y.shape"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "9be1d661-bfac-4826-ab8d-453557713f68",
+ "id": "0712ee7e-4f66-4fb1-bc91-d8a127eb7ac7",
"metadata": {},
"outputs": [],
"source": [
- "y().causal"
+ "net = net.cuda()\n",
+ "img = img.cuda()\n",
+ "y = y.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "54b35e6f-35db-4769-8bc5-ed1764768cf2",
+ "id": "719154b4-47db-4c91-bae4-8c572c4a4536",
"metadata": {},
"outputs": [],
"source": [
- "y(causal=True)"
+ "net(img, y).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "af2c8cfa-0b45-4681-b671-0f97ace62516",
+ "id": "bcb7db0f-0afe-44eb-9bb7-b988fbead95a",
"metadata": {},
"outputs": [],
"source": [
- "net = instantiate(cfg)"
+ "from torchsummary import summary"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f",
+ "id": "31af8ee1-28d3-46b8-a847-6506d29bc45c",
"metadata": {},
"outputs": [],
"source": [
- "net"
+ "summary(net, [(1, 576, 640), (451,)], device=\"cpu\", depth=2)"
]
},
{
"cell_type": "code",
"execution_count": null,
- "id": "709be6cc-6708-4561-ad45-28f433612a0d",
+ "id": "4d6d836f-d169-48b4-92e6-ca17179e6f85",
"metadata": {},
"outputs": [],
"source": []
diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py
index 408ae36..fd914b6 100644
--- a/text_recognizer/data/base_data_module.py
+++ b/text_recognizer/data/base_data_module.py
@@ -1,11 +1,12 @@
"""Base lightning DataModule class."""
from pathlib import Path
-from typing import Any, Dict, Tuple
+from typing import Dict, Tuple
import attr
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.data.base_dataset import BaseDataset
@@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
+ mapping: AbstractMapping = attr.ib()
batch_size: int = attr.ib(default=16)
num_workers: int = attr.ib(default=0)
+ pin_memory: bool = attr.ib(default=True)
# Placeholders
data_train: BaseDataset = attr.ib(init=False, default=None)
@@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule):
data_test: BaseDataset = attr.ib(init=False, default=None)
dims: Tuple[int, ...] = attr.ib(init=False, default=None)
output_dims: Tuple[int, ...] = attr.ib(init=False, default=None)
- mapping: Any = attr.ib(init=False, default=None)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
@classmethod
def data_dirname(cls) -> Path:
@@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule):
return {
"input_dim": self.dims,
"output_dims": self.output_dims,
- "mapping": self.mapping,
}
def prepare_data(self) -> None:
@@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule):
shuffle=True,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def val_dataloader(self) -> DataLoader:
@@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
def test_dataloader(self) -> DataLoader:
@@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule):
shuffle=False,
batch_size=self.batch_size,
num_workers=self.num_workers,
- pin_memory=True,
+ pin_memory=self.pin_memory,
)
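With `mapping` promoted to a required constructor attribute on BaseDataModule (and `pin_memory` made configurable), datamodules are now wired together roughly as follows. A hedged sketch, assuming IAMExtendedParagraphs keeps the attrs fields shown in this diff:

    from text_recognizer.data.mappings import EmnistMapping
    from text_recognizer.data.iam_extended_paragraphs import IAMExtendedParagraphs

    datamodule = IAMExtendedParagraphs(
        mapping=EmnistMapping(extra_symbols={"\n"}),  # injected, no longer built internally
        batch_size=16,
        num_workers=4,
        pin_memory=True,  # previously hard-coded in the dataloaders
    )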
diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py
index c26f1c9..8640d92 100644
--- a/text_recognizer/data/base_dataset.py
+++ b/text_recognizer/data/base_dataset.py
@@ -1,5 +1,5 @@
"""Base PyTorch Dataset class."""
-from typing import Any, Callable, Dict, Sequence, Tuple, Union
+from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import attr
import torch
@@ -22,14 +22,13 @@ class BaseDataset(Dataset):
data: Union[Sequence, Tensor] = attr.ib()
targets: Union[Sequence, Tensor] = attr.ib()
- transform: Callable = attr.ib()
- target_transform: Callable = attr.ib()
+ transform: Optional[Callable] = attr.ib(default=None)
+ target_transform: Optional[Callable] = attr.ib(default=None)
def __attrs_pre_init__(self) -> None:
super().__init__()
def __attrs_post_init__(self) -> None:
- # TODO: refactor this
if len(self.data) != len(self.targets):
raise ValueError("Data and targets must be of equal length.")
@@ -37,14 +36,14 @@ class BaseDataset(Dataset):
"""Return the length of the dataset."""
return len(self.data)
- def __getitem__(self, index: int) -> Tuple[Any, Any]:
+ def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]:
"""Return a datum and its target, after processing by transforms.
Args:
index (int): Index of a datum in the dataset.
Returns:
- Tuple[Any, Any]: Datum and target pair.
+ Tuple[Tensor, Tensor]: Datum and target pair.
"""
datum, target = self.data[index], self.targets[index]
diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py
index e3dc68c..8938830 100644
--- a/text_recognizer/data/download_utils.py
+++ b/text_recognizer/data/download_utils.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional
from urllib.request import urlretrieve
-from loguru import logger
+from loguru import logger as log
from tqdm import tqdm
@@ -32,7 +32,7 @@ class TqdmUpTo(tqdm):
total_size (Optional[int]): Total size in tqdm units. Defaults to None.
"""
if total_size is not None:
- self.total = total_size # pylint: disable=attribute-defined-outside-init
+ self.total = total_size
self.update(blocks * block_size - self.n)
@@ -62,9 +62,9 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]:
filename = dl_dir / metadata["filename"]
if filename.exists():
return
- logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
+ log.info(f"Downloading raw dataset from {metadata['url']} to {filename}...")
_download_url(metadata["url"], filename)
- logger.info("Computing the SHA-256...")
+ log.info("Computing the SHA-256...")
sha256 = _compute_sha256(filename)
if sha256 != metadata["sha256"]:
raise ValueError(
diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py
index 2d0ac29..c6be123 100644
--- a/text_recognizer/data/emnist.py
+++ b/text_recognizer/data/emnist.py
@@ -3,12 +3,12 @@ import json
import os
from pathlib import Path
import shutil
-from typing import Callable, Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple
import zipfile
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import toml
import torchvision.transforms as T
@@ -50,8 +50,7 @@ class EMNIST(BaseDataModule):
transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()]))
def __attrs_post_init__(self) -> None:
- self.mapping, self.inverse_mapping, input_shape = emnist_mapping()
- self.dims = (1, *input_shape)
+ self.dims = (1, *self.mapping.input_size)
def prepare_data(self) -> None:
"""Downloads dataset if not present."""
@@ -106,7 +105,7 @@ class EMNIST(BaseDataModule):
def emnist_mapping(
- extra_symbols: Optional[Sequence[str]] = None,
+ extra_symbols: Optional[Set[str]] = None,
) -> Tuple[List, Dict[str, int], List[int]]:
"""Return the EMNIST mapping."""
if not ESSENTIALS_FILENAME.exists():
@@ -130,7 +129,7 @@ def download_and_process_emnist() -> None:
def _process_raw_dataset(filename: str, dirname: Path) -> None:
"""Processes the raw EMNIST dataset."""
- logger.info("Unzipping EMNIST...")
+ log.info("Unzipping EMNIST...")
curdir = os.getcwd()
os.chdir(dirname)
content = zipfile.ZipFile(filename, "r")
@@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
from scipy.io import loadmat
- logger.info("Loading training data from .mat file")
+ log.info("Loading training data from .mat file")
data = loadmat("matlab/emnist-byclass.mat")
x_train = (
data["dataset"]["train"][0, 0]["images"][0, 0]
@@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS
if SAMPLE_TO_BALANCE:
- logger.info("Balancing classes to reduce amount of data")
+ log.info("Balancing classes to reduce amount of data")
x_train, y_train = _sample_to_balance(x_train, y_train)
x_test, y_test = _sample_to_balance(x_test, y_test)
- logger.info("Saving to HDF5 in a compressed format...")
+ log.info("Saving to HDF5 in a compressed format...")
PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
with h5py.File(PROCESSED_DATA_FILENAME, "w") as f:
f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf")
@@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf")
f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf")
- logger.info("Saving essential dataset parameters to text_recognizer/datasets...")
+ log.info("Saving essential dataset parameters to text_recognizer/datasets...")
mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]}
characters = _augment_emnist_characters(mapping.values())
essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])}
@@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None:
with ESSENTIALS_FILENAME.open(mode="w") as f:
json.dump(essentials, f)
- logger.info("Cleaning up...")
+ log.info("Cleaning up...")
shutil.rmtree("matlab")
os.chdir(curdir)
diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py
index 7548ad5..5298726 100644
--- a/text_recognizer/data/emnist_lines.py
+++ b/text_recognizer/data/emnist_lines.py
@@ -1,11 +1,11 @@
"""Dataset of generated text from EMNIST characters."""
from collections import defaultdict
from pathlib import Path
-from typing import Callable, Dict, Tuple
+from typing import Callable, List, Tuple
import attr
import h5py
-from loguru import logger
+from loguru import logger as log
import numpy as np
import torch
from torchvision import transforms
@@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule):
emnist: EMNIST = attr.ib(init=False, default=None)
def __attrs_post_init__(self) -> None:
- self.emnist = EMNIST()
- self.mapping = self.emnist.mapping
+ self.emnist = EMNIST(mapping=self.mapping)
max_width = (
int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap))
@@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule):
self._generate_data("test")
def setup(self, stage: str = None) -> None:
- logger.info("EMNISTLinesDataset loading data from HDF5...")
+ log.info("EMNISTLinesDataset loading data from HDF5...")
if stage == "fit" or stage is None:
print(self.data_filename)
with h5py.File(self.data_filename, "r") as f:
@@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule):
return basic + data
def _generate_data(self, split: str) -> None:
- logger.info(f"EMNISTLines generating data for {split}...")
+ log.info(f"EMNISTLines generating data for {split}...")
sentence_generator = SentenceGenerator(
self.max_length - 2
) # Subtract by 2 because start/end token
@@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule):
if split == "train":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_train
elif split == "val":
samples_by_char = _get_samples_by_char(
- emnist.x_train, emnist.y_train, emnist.mapping
+ emnist.x_train, emnist.y_train, self.mapping.mapping
)
num = self.num_val
else:
samples_by_char = _get_samples_by_char(
- emnist.x_test, emnist.y_test, emnist.mapping
+ emnist.x_test, emnist.y_test, self.mapping.mapping
)
num = self.num_test
@@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule):
self.dims,
)
y = convert_strings_to_labels(
- y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH
+ y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH
)
f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf")
f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf")
def _get_samples_by_char(
- samples: np.ndarray, labels: np.ndarray, mapping: Dict
+ samples: np.ndarray, labels: np.ndarray, mapping: List
) -> defaultdict:
samples_by_char = defaultdict(list)
for sample, label in zip(samples, labels):
diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py
index 3982c4f..7278eb2 100644
--- a/text_recognizer/data/iam.py
+++ b/text_recognizer/data/iam.py
@@ -7,7 +7,7 @@ import zipfile
import attr
from boltons.cacheutils import cachedproperty
-from loguru import logger
+from loguru import logger as log
import toml
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
@@ -92,7 +92,7 @@ class IAM(BaseDataModule):
def _extract_raw_dataset(filename: Path, dirname: Path) -> None:
- logger.info("Extracting IAM data...")
+ log.info("Extracting IAM data...")
curdir = os.getcwd()
os.chdir(dirname)
with zipfile.ZipFile(filename, "r") as f:
diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py
index 0e97801..ccf0759 100644
--- a/text_recognizer/data/iam_extended_paragraphs.py
+++ b/text_recognizer/data/iam_extended_paragraphs.py
@@ -4,7 +4,6 @@ from typing import Dict, List
import attr
from torch.utils.data import ConcatDataset
-from text_recognizer.data.base_dataset import BaseDataset
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
from text_recognizer.data.iam_paragraphs import IAMParagraphs
from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs
@@ -20,6 +19,7 @@ class IAMExtendedParagraphs(BaseDataModule):
def __attrs_post_init__(self) -> None:
self.iam_paragraphs = IAMParagraphs(
+ mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -27,6 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule):
word_pieces=self.word_pieces,
)
self.iam_synthetic_paragraphs = IAMSyntheticParagraphs(
+ mapping=self.mapping,
batch_size=self.batch_size,
num_workers=self.num_workers,
train_fraction=self.train_fraction,
@@ -36,7 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule):
self.dims = self.iam_paragraphs.dims
self.output_dims = self.iam_paragraphs.output_dims
- self.num_classes = self.iam_paragraphs.num_classes
def prepare_data(self) -> None:
"""Prepares the paragraphs data."""
@@ -58,7 +58,7 @@ class IAMExtendedParagraphs(BaseDataModule):
"""Returns info about the dataset."""
basic = (
"IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member
- f"Num classes: {len(self.num_classes)}\n"
+ f"Num classes: {len(self.mapping)}\n"
f"Dims: {self.dims}\n"
f"Output dims: {self.output_dims}\n"
)
diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py
index b7f3fdd..1c63729 100644
--- a/text_recognizer/data/iam_lines.py
+++ b/text_recognizer/data/iam_lines.py
@@ -2,15 +2,14 @@
If not created, will generate a handwritten lines dataset from the IAM paragraphs
dataset.
-
"""
import json
from pathlib import Path
import random
-from typing import Dict, List, Sequence, Tuple
+from typing import List, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
from PIL import Image, ImageFile, ImageOps
import numpy as np
from torch import Tensor
@@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data import image_utils
@@ -48,17 +47,13 @@ class IAMLines(BaseDataModule):
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- def __attrs_post_init__(self) -> None:
- # TODO: refactor this
- self.mapping, self.inverse_mapping, _ = emnist_mapping()
-
def prepare_data(self) -> None:
"""Creates the IAM lines dataset if not existing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info("Cropping IAM lines regions...")
- iam = IAM()
+ log.info("Cropping IAM lines regions...")
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
crops_test, labels_test = line_crops_and_labels(iam, "test")
@@ -66,7 +61,7 @@ class IAMLines(BaseDataModule):
shapes = np.array([crop.size for crop in crops_train + crops_test])
aspect_ratios = shapes[:, 0] / shapes[:, 1]
- logger.info("Saving images, labels, and statistics...")
+ log.info("Saving images, labels, and statistics...")
save_images_and_labels(
crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
)
@@ -91,7 +86,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Target length longer than max output length.")
y_train = convert_strings_to_labels(
- labels_train, self.inverse_mapping, length=self.output_dims[0]
+ labels_train, self.mapping.inverse_mapping, length=self.output_dims[0]
)
data_train = BaseDataset(
x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment)
@@ -110,7 +105,7 @@ class IAMLines(BaseDataModule):
raise ValueError("Taget length longer than max output length.")
y_test = convert_strings_to_labels(
- labels_test, self.inverse_mapping, length=self.output_dims[0]
+ labels_test, self.mapping.inverse_mapping, length=self.output_dims[0]
)
self.data_test = BaseDataset(
x_test, y_test, transform=get_transform(IMAGE_WIDTH)
diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py
index 0f3a2ce..6189f7d 100644
--- a/text_recognizer/data/iam_paragraphs.py
+++ b/text_recognizer/data/iam_paragraphs.py
@@ -4,7 +4,7 @@ from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
import numpy as np
from PIL import Image, ImageOps
import torchvision.transforms as T
@@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import (
split_dataset,
)
from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info
-from text_recognizer.data.emnist import emnist_mapping
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
-from text_recognizer.data.mappings import WordPieceMapping
from text_recognizer.data.transforms import WordPiece
@@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682
class IAMParagraphs(BaseDataModule):
"""IAM handwriting database paragraphs."""
- num_classes: int = attr.ib()
word_pieces: bool = attr.ib(default=False)
augment: bool = attr.ib(default=True)
train_fraction: float = attr.ib(default=0.8)
@@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule):
init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH)
)
output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1))
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
-
- def __attrs_post_init__(self) -> None:
- _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN])
def prepare_data(self) -> None:
"""Create data for training/testing."""
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info(
+ log.info(
"Cropping IAM paragraph regions and saving them along with labels..."
)
- iam = IAM()
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
properties = {}
@@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule):
crops, labels = _load_processed_crops_and_labels(split)
data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops]
targets = convert_strings_to_labels(
- strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0]
+ strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0]
)
return BaseDataset(
data,
@@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule):
target_transform=get_target_transform(self.word_pieces),
)
- logger.info(f"Loading IAM paragraph regions and lines for {stage}...")
+ log.info(f"Loading IAM paragraph regions and lines for {stage}...")
_validate_data_dims(input_dims=self.dims, output_dims=self.output_dims)
if stage == "fit" or stage is None:
diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py
index 93a13bb..bcd77b4 100644
--- a/text_recognizer/data/iam_preprocessor.py
+++ b/text_recognizer/data/iam_preprocessor.py
@@ -1,18 +1,16 @@
"""Preprocessor for extracting word letters from the IAM dataset.
The code is mostly stolen from:
-
https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py
-
"""
import collections
import itertools
from pathlib import Path
import re
-from typing import List, Optional, Union, Sequence
+from typing import List, Optional, Union, Set
import click
-from loguru import logger
+from loguru import logger as log
import torch
@@ -57,7 +55,7 @@ class Preprocessor:
lexicon_path: Optional[Union[str, Path]] = None,
use_words: bool = False,
prepend_wordsep: bool = False,
- special_tokens: Optional[Sequence[str]] = None,
+ special_tokens: Optional[Set[str]] = None,
) -> None:
self.wordsep = "▁"
self._use_word = use_words
@@ -186,7 +184,7 @@ def cli(
/ "iam"
/ "iamdb"
)
- logger.debug(f"Using data dir: {data_dir}")
+ log.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
else:
@@ -196,15 +194,15 @@ def cli(
preprocessor.extract_train_text()
processed_dir = data_dir.parents[2] / "processed" / "iam_lines"
- logger.debug(f"Saving processed files at: {processed_dir}")
+ log.debug(f"Saving processed files at: {processed_dir}")
if save_text is not None:
- logger.info("Saving training text")
+ log.info("Saving training text")
with open(processed_dir / save_text, "w") as f:
f.write("\n".join(t for t in preprocessor.text))
if save_tokens is not None:
- logger.info("Saving tokens")
+ log.info("Saving tokens")
with open(processed_dir / save_tokens, "w") as f:
f.write("\n".join(preprocessor.tokens))
diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py
index f00a494..c938f8b 100644
--- a/text_recognizer/data/iam_synthetic_paragraphs.py
+++ b/text_recognizer/data/iam_synthetic_paragraphs.py
@@ -3,7 +3,7 @@ import random
from typing import Any, List, Sequence, Tuple
import attr
-from loguru import logger
+from loguru import logger as log
import numpy as np
from PIL import Image
@@ -21,6 +21,7 @@ from text_recognizer.data.iam_paragraphs import (
IMAGE_SCALE_FACTOR,
resize_image,
)
+from text_recognizer.data.mappings import EmnistMapping
from text_recognizer.data.iam import IAM
from text_recognizer.data.iam_lines import (
line_crops_and_labels,
@@ -43,10 +44,10 @@ class IAMSyntheticParagraphs(IAMParagraphs):
if PROCESSED_DATA_DIRNAME.exists():
return
- logger.info("Preparing IAM lines for synthetic paragraphs dataset.")
- logger.info("Cropping IAM line regions and loading labels.")
+ log.info("Preparing IAM lines for synthetic paragraphs dataset.")
+ log.info("Cropping IAM line regions and loading labels.")
- iam = IAM()
+ iam = IAM(mapping=EmnistMapping())
iam.prepare_data()
crops_train, labels_train = line_crops_and_labels(iam, "train")
@@ -55,7 +56,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train]
crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test]
- logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
+ log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}")
save_images_and_labels(
crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME
)
@@ -64,7 +65,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
def setup(self, stage: str = None) -> None:
"""Loading synthetic dataset."""
- logger.info(f"IAM Synthetic dataset steup for stage {stage}...")
+ log.info(f"IAM Synthetic dataset steup for stage {stage}...")
if stage == "fit" or stage is None:
line_crops, line_labels = load_line_crops_and_labels(
@@ -76,7 +77,7 @@ class IAMSyntheticParagraphs(IAMParagraphs):
targets = convert_strings_to_labels(
strings=paragraphs_labels,
- mapping=self.inverse_mapping,
+ mapping=self.mapping.inverse_mapping,
length=self.output_dims[0],
)
self.data_train = BaseDataset(
@@ -144,7 +145,7 @@ def generate_synthetic_paragraphs(
[line_labels[i] for i in paragraph_indices]
)
if len(paragraph_label) > paragraphs_properties["label_length"]["max"]:
- logger.info(
+ log.info(
"Label longer than longest label in original IAM paragraph dataset - hence dropping."
)
continue
@@ -158,7 +159,7 @@ def generate_synthetic_paragraphs(
paragraph_crop.height > max_paragraph_shape[0]
or paragraph_crop.width > max_paragraph_shape[1]
):
- logger.info(
+ log.info(
"Crop larger than largest crop in original IAM paragraphs dataset - hence dropping"
)
continue
diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py
index ef9eb1b..40fbee4 100644
--- a/text_recognizer/data/make_wordpieces.py
+++ b/text_recognizer/data/make_wordpieces.py
@@ -10,7 +10,7 @@ from pathlib import Path
from typing import List, Optional, Union
import click
-from loguru import logger
+from loguru import logger as log
import sentencepiece as spm
from text_recognizer.data.iam_preprocessor import load_metadata
@@ -63,9 +63,9 @@ def save_pieces(
vocab: set,
) -> None:
"""Saves word pieces to disk."""
- logger.info(f"Generating word piece list of size {num_pieces}.")
+ log.info(f"Generating word piece list of size {num_pieces}.")
pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)]
- logger.info(f"Encoding vocabulary of size {len(vocab)}.")
+ log.info(f"Encoding vocabulary of size {len(vocab)}.")
encoded_vocab = [sp.encode_as_pieces(v) for v in vocab]
# Save pieces to file.
@@ -101,7 +101,7 @@ def cli(
data_dir = (
Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines"
)
- logger.debug(f"Using data dir: {data_dir}")
+ log.debug(f"Using data dir: {data_dir}")
if not data_dir.exists():
raise RuntimeError(f"Could not locate iamdb directory at {data_dir}")
else:
diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py
index b69e888..d1c64dd 100644
--- a/text_recognizer/data/mappings.py
+++ b/text_recognizer/data/mappings.py
@@ -1,18 +1,30 @@
"""Mapping to and from word pieces."""
from abc import ABC, abstractmethod
from pathlib import Path
-from typing import Dict, List, Optional, Union, Set, Sequence
+from typing import Dict, List, Optional, Union, Set
import attr
-import loguru.logger as log
import torch
+from loguru import logger as log
from torch import Tensor
from text_recognizer.data.emnist import emnist_mapping
from text_recognizer.data.iam_preprocessor import Preprocessor
+@attr.s
class AbstractMapping(ABC):
+ input_size: List[int] = attr.ib(init=False)
+ mapping: List[str] = attr.ib(init=False)
+ inverse_mapping: Dict[str, int] = attr.ib(init=False)
+
+ def __len__(self) -> int:
+ return len(self.mapping)
+
+ @property
+ def num_classes(self) -> int:
+ return self.__len__()
+
@abstractmethod
def get_token(self, *args, **kwargs) -> str:
...
@@ -30,15 +42,13 @@ class AbstractMapping(ABC):
...
-@attr.s
+@attr.s(auto_attribs=True)
class EmnistMapping(AbstractMapping):
- extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set)
- mapping: Sequence[str] = attr.ib(init=False)
- inverse_mapping: Dict[str, int] = attr.ib(init=False)
- input_size: List[int] = attr.ib(init=False)
+ extra_symbols: Optional[Set[str]] = attr.ib(default=None)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
+ self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None
self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
self.extra_symbols
)
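The new attrs-based AbstractMapping gives every mapping `__len__` and a `num_classes` property, which is what lets IAMExtendedParagraphs report `len(self.mapping)` above. A small usage sketch:

    from text_recognizer.data.mappings import EmnistMapping

    mapping = EmnistMapping(extra_symbols={"\n"})
    assert mapping.num_classes == len(mapping)  # property delegates to __len__
    print(mapping.inverse_mapping["<p>"])       # 3, per the notebook output above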
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index caf63c1..8ce5c37 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -12,7 +12,7 @@ from torch import Tensor
import torchmetrics
-@attr.s
+@attr.s(eq=False)
class BaseLitModel(LightningModule):
"""Abstract PyTorch Lightning class."""
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 0eb42dc..f83c9e4 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -8,11 +8,11 @@ from torch import Tensor
from torchmetrics import Metric
-@attr.s
+@attr.s(eq=False)
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_indices: Set = attr.ib(converter=set)
+ ignore_indices: Set[Tensor] = attr.ib(converter=set)
error: Tensor = attr.ib(init=False)
total: Tensor = attr.ib(init=False)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 0e01bb5..91e088d 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,5 +1,5 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Sequence, Tuple, Type
+from typing import Tuple, Type, Set
import attr
import torch
@@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
- mapping: Type[AbstractMapping] = attr.ib()
- start_token: str = attr.ib()
- end_token: str = attr.ib()
- pad_token: str = attr.ib()
+ mapping: Type[AbstractMapping] = attr.ib(default=None)
+ start_token: str = attr.ib(default="<s>")
+ end_token: str = attr.ib(default="<e>")
+ pad_token: str = attr.ib(default="<p>")
start_index: Tensor = attr.ib(init=False)
end_index: Tensor = attr.ib(init=False)
pad_index: Tensor = attr.ib(init=False)
- ignore_indices: Sequence[str] = attr.ib(init=False)
+ ignore_indices: Set[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
test_cer: CharacterErrorRate = attr.ib(init=False)
@@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel):
self.start_index = self.mapping.get_index(self.start_token)
self.end_index = self.mapping.get_index(self.end_token)
self.pad_index = self.mapping.get_index(self.pad_token)
- self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+ self.ignore_indices = set([self.start_index, self.end_index, self.pad_index])
self.val_cer = CharacterErrorRate(self.ignore_indices)
self.test_cer = CharacterErrorRate(self.ignore_indices)
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index e215e14..22da018 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -10,7 +10,7 @@ import wandb
from text_recognizer.models.base import BaseLitModel
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py
index 7371be4..09cc654 100644
--- a/text_recognizer/networks/conv_transformer.py
+++ b/text_recognizer/networks/conv_transformer.py
@@ -13,7 +13,7 @@ from text_recognizer.networks.transformer.positional_encodings import (
)
-@attr.s
+@attr.s(eq=False)
class ConvTransformer(nn.Module):
"""Convolutional encoder and transformer decoder network."""
@@ -121,6 +121,7 @@ class ConvTransformer(nn.Module):
Returns:
Tensor: Sequence of word piece embeddings.
"""
+ context = context.long()
context_mask = context != self.pad_index
context = self.token_embedding(context) * math.sqrt(self.hidden_dim)
context = self.token_pos_encoder(context)
diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
index a36150a..b8eb53b 100644
--- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py
+++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py
@@ -1,4 +1,4 @@
-"""Efficient net."""
+"""Efficientnet backbone."""
from typing import Tuple
import attr
@@ -12,8 +12,10 @@ from .utils import (
)
-@attr.s
+@attr.s(eq=False)
class EfficientNet(nn.Module):
+ """Efficientnet without classification head."""
+
def __attrs_pre_init__(self) -> None:
super().__init__()
@@ -47,11 +49,13 @@ class EfficientNet(nn.Module):
@arch.validator
def check_arch(self, attribute: attr._make.Attribute, value: str) -> None:
+ """Validates the efficientnet architecure."""
if value not in self.archs:
raise ValueError(f"{value} not a valid architecure.")
self.params = self.archs[value]
def _build(self) -> None:
+ """Builds the efficientnet backbone."""
_block_args = block_args()
in_channels = 1 # BW
out_channels = round_filters(32, self.params)
@@ -73,8 +77,9 @@ class EfficientNet(nn.Module):
for args in _block_args:
args.in_channels = round_filters(args.in_channels, self.params)
args.out_channels = round_filters(args.out_channels, self.params)
- args.num_repeats = round_repeats(args.num_repeats, self.params)
- for _ in range(args.num_repeats):
+ num_repeats = round_repeats(args.num_repeats, self.params)
+ del args.num_repeats
+ for _ in range(num_repeats):
self._blocks.append(
MBConvBlock(
**args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps,
@@ -93,6 +98,7 @@ class EfficientNet(nn.Module):
)
def extract_features(self, x: Tensor) -> Tensor:
+ """Extracts the final feature map layer."""
x = self._conv_stem(x)
for i, block in enumerate(self._blocks):
stochastic_dropout_rate = self.stochastic_dropout_rate
@@ -103,4 +109,5 @@ class EfficientNet(nn.Module):
return x
def forward(self, x: Tensor) -> Tensor:
+ """Returns efficientnet image features."""
return self.extract_features(x)
diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py
index 3aa63d0..e85df87 100644
--- a/text_recognizer/networks/encoders/efficientnet/mbconv.py
+++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py
@@ -1,76 +1,62 @@
"""Mobile inverted residual block."""
-from typing import Any, Optional, Union, Tuple
+from typing import Optional, Sequence, Union, Tuple
+import attr
import torch
from torch import nn, Tensor
import torch.nn.functional as F
-from .utils import stochastic_depth
+from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth
+def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
+ """Converts int to tuple."""
+ return (
+ (stride,) * 2 if isinstance(stride, int) else stride
+ )
+
+
+@attr.s(eq=False)
class MBConvBlock(nn.Module):
"""Mobile Inverted Residual Bottleneck block."""
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- bn_momentum: float,
- bn_eps: float,
- se_ratio: float,
- expand_ratio: int,
- *args: Any,
- **kwargs: Any,
- ) -> None:
+ def __attrs_pre_init__(self) -> None:
super().__init__()
- self.kernel_size = kernel_size
- self.stride = (stride,) * 2 if isinstance(stride, int) else stride
- self.bn_momentum = bn_momentum
- self.bn_eps = bn_eps
- self.in_channels = in_channels
- self.out_channels = out_channels
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ kernel_size: Tuple[int, int] = attr.ib()
+ stride: Tuple[int, int] = attr.ib(converter=_convert_stride)
+ bn_momentum: float = attr.ib()
+ bn_eps: float = attr.ib()
+ se_ratio: float = attr.ib()
+ expand_ratio: int = attr.ib()
+ pad: Tuple[int, int, int, int] = attr.ib(init=False)
+ _inverted_bottleneck: nn.Sequential = attr.ib(init=False)
+ _depthwise: nn.Sequential = attr.ib(init=False)
+ _squeeze_excite: nn.Sequential = attr.ib(init=False)
+ _pointwise: nn.Sequential = attr.ib(init=False)
+
+ @pad.default
+ def _configure_padding(self) -> Tuple[int, int, int, int]:
+ """Set padding for convolutional layers."""
if self.stride == (2, 2):
- self.pad = [
+ return (
(self.kernel_size - 1) // 2 - 1,
(self.kernel_size - 1) // 2,
- ] * 2
- else:
- self.pad = [(self.kernel_size - 1) // 2] * 4
-
- # Placeholders for layers.
- self._inverted_bottleneck: nn.Sequential = None
- self._depthwise: nn.Sequential = None
- self._squeeze_excite: nn.Sequential = None
- self._pointwise: nn.Sequential = None
-
- self._build(
- in_channels=in_channels,
- out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
- expand_ratio=expand_ratio,
- se_ratio=se_ratio,
- )
+ ) * 2
+ return ((self.kernel_size - 1) // 2,) * 4
+
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self._build()
- def _build(
- self,
- in_channels: int,
- out_channels: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
- expand_ratio: int,
- se_ratio: float,
- ) -> None:
- has_se = se_ratio is not None and 0.0 < se_ratio < 1.0
- inner_channels = in_channels * expand_ratio
+ def _build(self) -> None:
+ has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0
+ inner_channels = self.in_channels * self.expand_ratio
self._inverted_bottleneck = (
- self._configure_inverted_bottleneck(
- in_channels=in_channels, out_channels=inner_channels,
- )
- if expand_ratio != 1
+ self._configure_inverted_bottleneck(out_channels=inner_channels)
+ if self.expand_ratio != 1
else None
)
@@ -78,31 +64,23 @@ class MBConvBlock(nn.Module):
in_channels=inner_channels,
out_channels=inner_channels,
groups=inner_channels,
- kernel_size=kernel_size,
- stride=stride,
)
self._squeeze_excite = (
self._configure_squeeze_excite(
- in_channels=inner_channels,
- out_channels=inner_channels,
- se_ratio=se_ratio,
+ in_channels=inner_channels, out_channels=inner_channels,
)
if has_se
else None
)
- self._pointwise = self._configure_pointwise(
- in_channels=inner_channels, out_channels=out_channels
- )
+ self._pointwise = self._configure_pointwise(in_channels=inner_channels)
- def _configure_inverted_bottleneck(
- self, in_channels: int, out_channels: int,
- ) -> nn.Sequential:
+ def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential:
"""Expansion phase."""
return nn.Sequential(
nn.Conv2d(
- in_channels=in_channels,
+ in_channels=self.in_channels,
out_channels=out_channels,
kernel_size=1,
bias=False,
@@ -114,19 +92,14 @@ class MBConvBlock(nn.Module):
)
def _configure_depthwise(
- self,
- in_channels: int,
- out_channels: int,
- groups: int,
- kernel_size: int,
- stride: Union[Tuple[int, int], int],
+ self, in_channels: int, out_channels: int, groups: int,
) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channels,
- kernel_size=kernel_size,
- stride=stride,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
groups=groups,
bias=False,
),
@@ -137,9 +110,9 @@ class MBConvBlock(nn.Module):
)
def _configure_squeeze_excite(
- self, in_channels: int, out_channels: int, se_ratio: float
+ self, in_channels: int, out_channels: int
) -> nn.Sequential:
- num_squeezed_channels = max(1, int(in_channels * se_ratio))
+ num_squeezed_channels = max(1, int(in_channels * self.se_ratio))
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
@@ -154,18 +127,18 @@ class MBConvBlock(nn.Module):
),
)
- def _configure_pointwise(
- self, in_channels: int, out_channels: int
- ) -> nn.Sequential:
+ def _configure_pointwise(self, in_channels: int) -> nn.Sequential:
return nn.Sequential(
nn.Conv2d(
in_channels=in_channels,
- out_channels=out_channels,
+ out_channels=self.out_channels,
kernel_size=1,
bias=False,
),
nn.BatchNorm2d(
- num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps
+ num_features=self.out_channels,
+ momentum=self.bn_momentum,
+ eps=self.bn_eps,
),
)
@@ -186,8 +159,8 @@ class MBConvBlock(nn.Module):
residual = x
if self._inverted_bottleneck is not None:
x = self._inverted_bottleneck(x)
- x = F.pad(x, self.pad)
+ x = F.pad(x, self.pad)
x = self._depthwise(x)
if self._squeeze_excite is not None:
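The stride normalization that previously lived in `__init__` now runs through an attrs converter, which fires before the value is stored. A standalone sketch of the pattern:

    import attr
    from typing import Tuple, Union

    def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]:
        """Converts int to tuple."""
        return (stride,) * 2 if isinstance(stride, int) else stride

    @attr.s
    class Block:
        stride: Tuple[int, int] = attr.ib(converter=_convert_stride)

    assert Block(stride=2).stride == (2, 2)
    assert Block(stride=(1, 2)).stride == (1, 2)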
diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index 9202cce..37ce29e 100644
--- a/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
@@ -15,7 +15,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
)
-@attr.s
+@attr.s(eq=False)
class Attention(nn.Module):
"""Standard attention."""
@@ -31,7 +31,6 @@ class Attention(nn.Module):
dropout: nn.Dropout = attr.ib(init=False)
fc: nn.Linear = attr.ib(init=False)
qkv_fn: nn.Sequential = attr.ib(init=False)
- attn_fn: F.softmax = attr.ib(init=False, default=F.softmax)
def __attrs_post_init__(self) -> None:
"""Post init configuration."""
@@ -80,7 +79,7 @@ class Attention(nn.Module):
else k_mask
)
q_mask = rearrange(q_mask, "b i -> b () i ()")
- k_mask = rearrange(k_mask, "b i -> b () () j")
+ k_mask = rearrange(k_mask, "b j -> b () () j")
return q_mask * k_mask
return
@@ -129,7 +128,7 @@ class Attention(nn.Module):
if self.causal:
energy = self._apply_causal_mask(energy, mask, mask_value, device)
- attn = self.attn_fn(energy, dim=-1)
+ attn = F.softmax(energy, dim=-1)
attn = self.dropout(attn)
out = einsum("b h i j, b h j d -> b h i d", attn, v)
out = rearrange(out, "b h n d -> b n (h d)")
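The one-character mask fix matters because einops requires every axis named on the right side of the pattern to appear on the left (apart from unit axes), so `"b i -> b () () j"` raises an EinopsError. A sketch of the corrected broadcast:

    import torch
    from einops import rearrange

    k_mask = torch.ones(4, 10, dtype=torch.bool)    # (batch, key_len)
    # rearrange(k_mask, "b i -> b () () j")         # EinopsError: 'j' not in input
    k_mask = rearrange(k_mask, "b j -> b () () j")  # shape (4, 1, 1, 10)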
diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py
index 66c9c50..ce443e5 100644
--- a/text_recognizer/networks/transformer/layers.py
+++ b/text_recognizer/networks/transformer/layers.py
@@ -12,7 +12,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding
from text_recognizer.networks.util import load_partial_fn
-@attr.s
+@attr.s(eq=False)
class AttentionLayers(nn.Module):
"""Standard transfomer layer."""
@@ -101,11 +101,11 @@ class AttentionLayers(nn.Module):
return x
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class Encoder(AttentionLayers):
causal: bool = attr.ib(default=False, init=False)
-@attr.s(auto_attribs=True)
+@attr.s(auto_attribs=True, eq=False)
class Decoder(AttentionLayers):
causal: bool = attr.ib(default=True, init=False)
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
deleted file mode 100644
index 0017e11..0000000
--- a/training/conf/callbacks/wandb.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-defaults:
- - default.yaml
-
-watch_model:
- _target_: callbacks.wandb_callbacks.WatchModel
- log: all
- log_freq: 100
-
-upload_code_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
- project_dir: ${work_dir}/text_recognizer
-
-upload_ckpts_as_artifact:
- _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
- ckpt_dir: checkpoints/
- upload_best_only: true
-
-log_text_predictions:
- _target_: callbacks.wandb_callbacks.LogTextPredictions
- num_samples: 8
diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml
new file mode 100644
index 0000000..a4a16ff
--- /dev/null
+++ b/training/conf/callbacks/wandb/checkpoints.yaml
@@ -0,0 +1,4 @@
+upload_ckpts_as_artifact:
+ _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+ ckpt_dir: checkpoints/
+ upload_best_only: true
diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb/code.yaml
new file mode 100644
index 0000000..35f6ea3
--- /dev/null
+++ b/training/conf/callbacks/wandb/code.yaml
@@ -0,0 +1,3 @@
+upload_code_as_artifact:
+ _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
+ project_dir: ${work_dir}/text_recognizer
diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb/image_reconstructions.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/conf/callbacks/wandb/image_reconstructions.yaml
diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb/ocr_predictions.yaml
new file mode 100644
index 0000000..573fa96
--- /dev/null
+++ b/training/conf/callbacks/wandb/ocr_predictions.yaml
@@ -0,0 +1,3 @@
+log_text_predictions:
+ _target_: callbacks.wandb_callbacks.LogTextPredictions
+ num_samples: 8
diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml
new file mode 100644
index 0000000..511608c
--- /dev/null
+++ b/training/conf/callbacks/wandb/watch.yaml
@@ -0,0 +1,4 @@
+watch_model:
+ _target_: callbacks.wandb_callbacks.WatchModel
+ log: all
+ log_freq: 100
diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml
new file mode 100644
index 0000000..efa3dda
--- /dev/null
+++ b/training/conf/callbacks/wandb_ocr.yaml
@@ -0,0 +1,6 @@
+defaults:
+ - default
+ - wandb/watch
+ - wandb/code
+ - wandb/checkpoints
+ - wandb/ocr_predictions
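
Note on the callback reshuffle: the monolithic wandb.yaml is split into one file per callback under conf/callbacks/wandb/, and wandb_ocr.yaml recomposes the pieces through its defaults list, so other experiment presets can pick only the callbacks they need. Hydra merges the fragments into a single callbacks node. A hypothetical sanity check, assuming Hydra >= 1.1 and a working directory of training/:

    from hydra import compose, initialize
    from omegaconf import OmegaConf

    with initialize(config_path="conf"):
        cfg = compose(config_name="config")
        # Expect watch_model, upload_code_as_artifact, upload_ckpts_as_artifact
        # and log_text_predictions merged under a single callbacks node.
        print(OmegaConf.to_yaml(cfg.callbacks))
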
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index a8e718e..93215ed 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,19 +1,17 @@
defaults:
- - network: vqvae
- - criterion: mse
- - optimizer: madgrad
- - lr_scheduler: one_cycle
- - model: lit_vqvae
+ - callbacks: wandb_ocr
+ - criterion: label_smoothing
- dataset: iam_extended_paragraphs
+ - hydra: default
+ - lr_scheduler: one_cycle
+ - mapping: word_piece
+ - model: lit_transformer
+ - network: conv_transformer
+ - optimizer: madgrad
- trainer: default
- - callbacks:
- - checkpoint
- - learning_rate_monitor
seed: 4711
-wandb: false
tune: false
train: true
test: true
-load_checkpoint: null
logging: INFO
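
The rewritten defaults list swaps the VQVAE training setup for the transformer OCR one and is now kept alphabetized. The ad-hoc wandb and load_checkpoint flags disappear: W&B logging is enabled by selecting the callbacks: wandb_ocr group, and resuming appears to move to the trainer's resume_from_checkpoint, added further down in this diff.
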
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
index ee47c59..13daba8 100644
--- a/training/conf/criterion/label_smoothing.yaml
+++ b/training/conf/criterion/label_smoothing.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.criterion.label_smoothing
+_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
label_smoothing: 0.1
vocab_size: 1006
ignore_index: 1002
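
This fixes a real instantiation bug: hydra.utils.instantiate imports the dotted path and calls it, so _target_ must name a callable, not a module; the old value pointed at the label_smoothing module itself and would fail at run time. A minimal sketch, assuming Hydra >= 1.1 and the package importable:

    from hydra.utils import instantiate

    cfg = {
        "_target_": "text_recognizer.criterion.label_smoothing.LabelSmoothingLoss",
        "label_smoothing": 0.1,
        "vocab_size": 1006,
        "ignore_index": 1002,
    }
    criterion = instantiate(cfg)  # imports the class, then calls it with the kwargs
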
diff --git a/training/conf/hydra/default.yaml b/training/conf/hydra/default.yaml
new file mode 100644
index 0000000..dfd9721
--- /dev/null
+++ b/training/conf/hydra/default.yaml
@@ -0,0 +1,6 @@
+# output paths for hydra logs
+run:
+ dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+sweep:
+ dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
+ subdir: ${hydra.job.num}
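
Moving these settings into a hydra config group keeps run artifacts out of the project root: each run logs under a timestamped directory (for example logs/runs/2021-08-02/21-13-48) and each multirun job gets its own subdirectory via ${hydra.job.num}.
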
diff --git a/training/conf/model/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index 39e2ba4..3792523 100644
--- a/training/conf/model/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -5,5 +5,5 @@ lexicon: iamdb_1kwp_lex_1000.txt
data_dir: null
use_words: false
prepend_wordsep: false
-special_tokens: ["<s>", "<e>", "<p>"]
-extra_symbols: ["\n"]
+special_tokens: [ <s>, <e>, <p> ]
+extra_symbols: [ \n ]
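
One caveat on the unquoting: it is purely cosmetic for tokens like <s>, but YAML only interprets escape sequences inside double-quoted scalars, so the plain [ \n ] parses as the two-character string backslash plus n rather than a newline. Worth double-checking that the word-piece mapping expects that. A quick illustration with PyYAML (OmegaConf parses YAML by the same rules):

    import yaml

    print(yaml.safe_load(r"symbols: [ \n ]"))    # {'symbols': ['\\n']}  backslash + n
    print(yaml.safe_load('symbols: [ "\\n" ]'))  # {'symbols': ['\n']}   a real newline
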
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 5341d8e..6ffde4e 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,8 +1,5 @@
-defaults:
- - mapping: word_piece
-
_target_: text_recognizer.models.transformer.TransformerLitModel
-interval: null
+interval: step
monitor: val/loss
ignore_tokens: [ <s>, <e>, <p> ]
start_token: <s>
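
interval: step makes the model advance the LR scheduler every optimizer step instead of every epoch, which is what the one_cycle scheduler in the defaults expects. The per-model mapping default is dropped because mapping is now a top-level config group (see the rename above), instantiated once in training/run.py further down.
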
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 7d57a2d..a97157d 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -4,7 +4,7 @@ defaults:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
-hidden_dim: 256
+hidden_dim: 96
dropout_rate: 0.2
max_output_len: 451
num_classes: 1006
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 3122de1..90b9d8a 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -2,12 +2,12 @@ defaults:
- rotary_emb: null
_target_: text_recognizer.networks.transformer.Decoder
-dim: 256
+dim: 96
depth: 2
num_heads: 8
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
- dim_head: 64
+ dim_head: 16
dropout_rate: 0.2
norm_fn: torch.nn.LayerNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
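
The width reduction here tracks the hidden_dim: 96 change in conv_transformer.yaml above, keeping the encoder output and the decoder model dim matched. Note that dim_head is independent of dim: with num_heads: 8 and dim_head: 16 the attention runs at an inner width of 8 * 16 = 128 and projects back to 96. A pure-torch sketch of that arithmetic (not the repo's Attention class; the attention math itself is omitted):

    import torch
    from torch import nn

    dim, num_heads, dim_head = 96, 8, 16
    inner_dim = num_heads * dim_head             # 128, decoupled from the model dim
    to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
    to_out = nn.Linear(inner_dim, dim)

    x = torch.randn(1, 10, dim)                  # (batch, seq, dim)
    q, k, v = to_qkv(x).chunk(3, dim=-1)         # each (1, 10, 128)
    out = to_out(v)                              # projected back to (1, 10, 96)
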
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index 5ed6552..c665adc 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -6,6 +6,10 @@ gradient_clip_val: 0
fast_dev_run: false
gpus: 1
precision: 16
-max_epochs: 64
+max_epochs: 512
terminate_on_nan: true
weights_summary: top
+limit_train_batches: 1.0
+limit_val_batches: 1.0
+limit_test_batches: 1.0
+resume_from_checkpoint: null
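
These are standard PyTorch Lightning Trainer arguments: the limit_*_batches defaults of 1.0 keep full passes while making fractional overrides easy from the command line (for example trainer.limit_train_batches=0.1 for a smoke test), and resume_from_checkpoint: null appears to take over from the deleted top-level load_checkpoint flag.
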
diff --git a/training/run.py b/training/run.py
index d88a8f6..30479c6 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,7 +2,7 @@
from typing import List, Optional, Type
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig
from pytorch_lightning import (
Callback,
@@ -12,6 +12,7 @@ from pytorch_lightning import (
Trainer,
)
from pytorch_lightning.loggers import LightningLoggerBase
+from text_recognizer.data.mappings import AbstractMapping
from torch import nn
import utils
@@ -25,15 +26,19 @@ def run(config: DictConfig) -> Optional[float]:
if config.get("seed"):
seed_everything(config.seed)
+ log.info(f"Instantiating mapping <{config.mapping._target_}>")
+ mapping: AbstractMapping = hydra.utils.instantiate(config.mapping)
+
log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
- datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
+ datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)
log.info(f"Instantiating network <{config.network._target_}>")
- network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config())
+ network: nn.Module = hydra.utils.instantiate(config.network)
log.info(f"Instantiating model <{config.model._target_}>")
model: LightningModule = hydra.utils.instantiate(
**config.model,
+ mapping=mapping,
network=network,
criterion_config=config.criterion,
optimizer_config=config.optimizer,
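
Two things change in run.py. First, the mapping is instantiated once up front and injected into both the datamodule and the model, so the two cannot drift apart on token-to-index assignments; the network likewise stops pulling its dimensions from datamodule.config() and reads them from its own config group. Second, the loguru import is fixed (the same one-liner recurs in training/utils.py below): logger is an object created in the package's __init__, not a submodule, so the "import a.b as c" form fails during module resolution. A minimal illustration:

    # logger is an attribute of the loguru package, not a submodule:
    #     import loguru.logger as log   # ModuleNotFoundError: No module named 'loguru.logger'
    # Attribute imports must use the "from" form instead:
    from loguru import logger as log

    log.info("loguru imported correctly")
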
diff --git a/training/utils.py b/training/utils.py
index 564b9bb..ef74f61 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -3,7 +3,7 @@ from typing import Any, List, Type
import warnings
import hydra
-import loguru.logger as log
+from loguru import logger as log
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import (
Callback,