From 75801019981492eedf9280cb352eea3d8e99b65f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Mon, 2 Aug 2021 21:13:48 +0200 Subject: Fix log import, fix mapping in datamodules, fix nn modules can be hashed --- notebooks/03-look-at-iam-paragraphs.ipynb | 18 ++- notebooks/05c-test-model-end-to-end.ipynb | 152 ++++++++------------- text_recognizer/data/base_data_module.py | 14 +- text_recognizer/data/base_dataset.py | 11 +- text_recognizer/data/download_utils.py | 8 +- text_recognizer/data/emnist.py | 21 ++- text_recognizer/data/emnist_lines.py | 21 ++- text_recognizer/data/iam.py | 4 +- text_recognizer/data/iam_extended_paragraphs.py | 6 +- text_recognizer/data/iam_lines.py | 21 ++- text_recognizer/data/iam_paragraphs.py | 18 +-- text_recognizer/data/iam_preprocessor.py | 16 +-- text_recognizer/data/iam_synthetic_paragraphs.py | 19 +-- text_recognizer/data/make_wordpieces.py | 8 +- text_recognizer/data/mappings.py | 24 +++- text_recognizer/models/base.py | 2 +- text_recognizer/models/metrics.py | 4 +- text_recognizer/models/transformer.py | 16 +-- text_recognizer/models/vqvae.py | 2 +- text_recognizer/networks/conv_transformer.py | 3 +- .../networks/encoders/efficientnet/efficientnet.py | 15 +- .../networks/encoders/efficientnet/mbconv.py | 139 ++++++++----------- text_recognizer/networks/transformer/attention.py | 7 +- text_recognizer/networks/transformer/layers.py | 6 +- training/conf/callbacks/wandb.yaml | 20 --- training/conf/callbacks/wandb/checkpoints.yaml | 4 + training/conf/callbacks/wandb/code.yaml | 3 + .../callbacks/wandb/image_reconstructions.yaml | 0 training/conf/callbacks/wandb/ocr_predictions.yaml | 3 + training/conf/callbacks/wandb/watch.yaml | 4 + training/conf/callbacks/wandb_ocr.yaml | 6 + training/conf/config.yaml | 18 ++- training/conf/criterion/label_smoothing.yaml | 2 +- training/conf/hydra/default.yaml | 6 + training/conf/mapping/word_piece.yaml | 9 ++ training/conf/model/lit_transformer.yaml | 5 +- training/conf/model/mapping/word_piece.yaml | 9 -- training/conf/network/conv_transformer.yaml | 2 +- .../conf/network/decoder/transformer_decoder.yaml | 4 +- training/conf/trainer/default.yaml | 6 +- training/run.py | 11 +- training/utils.py | 2 +- 42 files changed, 314 insertions(+), 355 deletions(-) delete mode 100644 training/conf/callbacks/wandb.yaml create mode 100644 training/conf/callbacks/wandb/checkpoints.yaml create mode 100644 training/conf/callbacks/wandb/code.yaml create mode 100644 training/conf/callbacks/wandb/image_reconstructions.yaml create mode 100644 training/conf/callbacks/wandb/ocr_predictions.yaml create mode 100644 training/conf/callbacks/wandb/watch.yaml create mode 100644 training/conf/callbacks/wandb_ocr.yaml create mode 100644 training/conf/hydra/default.yaml create mode 100644 training/conf/mapping/word_piece.yaml delete mode 100644 training/conf/model/mapping/word_piece.yaml diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index 5e3a872..76ca6b1 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -5,7 +5,21 @@ "execution_count": 1, "id": "6ce2519f", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'loguru.logger'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_3883/2979229631.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0msys\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'..'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_synthetic_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMSyntheticParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam_extended_paragraphs\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAMExtendedParagraphs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/projects/text-recognizer/text_recognizer/data/iam_paragraphs.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0memnist\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0memnist_mapping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miam\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mIAM\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 22\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmappings\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWordPieceMapping\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 23\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtext_recognizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtransforms\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mWordPiece\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 24\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/projects/text-recognizer/text_recognizer/data/mappings.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mattr\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mloguru\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogger\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mlog\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mtorch\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'loguru.logger'" + ] + } + ], "source": [ "import os\n", "os.environ['CUDA_VISIBLE_DEVICE'] = ''\n", @@ -31,7 +45,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "id": "726ac25b", "metadata": {}, "outputs": [], diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index a96e484..b652bdd 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -26,16 +26,6 @@ { "cell_type": "code", "execution_count": 2, - "id": "3e812a1e", - "metadata": {}, - "outputs": [], - "source": [ - "import attr" - ] - }, - { - "cell_type": "code", - "execution_count": 3, "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0", "metadata": {}, "outputs": [], @@ -47,193 +37,163 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "8741a844-3b97-47c4-a2a1-5a268d40923c", + "execution_count": 3, + "id": "6b722ca0-9c65-4f90-be4e-b7334ea81237", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "_target_: text_recognizer.data.mappings.WordPieceMapping\n", - "num_features: 1000\n", - "tokens: iamdb_1kwp_tokens_1000.txt\n", - "lexicon: iamdb_1kwp_lex_1000.txt\n", - "data_dir: null\n", - "use_words: false\n", - "prepend_wordsep: false\n", - "special_tokens:\n", + "mapping:\n", + " _target_: text_recognizer.data.mappings.WordPieceMapping\n", + " num_features: 1000\n", + " tokens: iamdb_1kwp_tokens_1000.txt\n", + " lexicon: iamdb_1kwp_lex_1000.txt\n", + " data_dir: null\n", + " use_words: false\n", + " prepend_wordsep: false\n", + " special_tokens:\n", + " - \n", + " - \n", + " -

\n", + " extra_symbols:\n", + " - \\n\n", + "_target_: text_recognizer.models.transformer.TransformerLitModel\n", + "interval: step\n", + "monitor: val/loss\n", + "ignore_tokens:\n", "- \n", "- \n", "-

\n", - "extra_symbols:\n", - "- '\n", + "start_token: \n", + "end_token: \n", + "pad_token:

\n", "\n", - " '\n", - "\n", - "{'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['', '', '

'], 'extra_symbols': ['\\n']}\n" + "{'mapping': {'_target_': 'text_recognizer.data.mappings.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['', '', '

'], 'extra_symbols': ['\\\\n']}, '_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'ignore_tokens': ['', '', '

'], 'start_token': '', 'end_token': '', 'pad_token': '

'}\n" ] } ], "source": [ "# context initialization\n", - "with initialize(config_path=\"../training/conf/model/mapping\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"word_piece\")\n", + "with initialize(config_path=\"../training/conf/model/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"lit_transformer\")\n", " print(OmegaConf.to_yaml(cfg))\n", " print(cfg)" ] }, - { - "cell_type": "code", - "execution_count": 5, - "id": "c9271d46-37b1-4d06-a603-46b5ed82f821", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-07-30 23:08:27.495 | DEBUG | text_recognizer.data.mappings:__attrs_post_init__:89 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" - ] - } - ], - "source": [ - "tt =instantiate(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "bf1b07ac-9de7-4d24-a36b-09847bc6bc6f", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "WordPieceMapping(extra_symbols={'\\n'}, mapping=['', '', '', '

', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', ' ', '!', '\"', '#', '&', \"'\", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '?', '\\n'], inverse_mapping={'': 0, '': 1, '': 2, '

': 3, '0': 4, '1': 5, '2': 6, '3': 7, '4': 8, '5': 9, '6': 10, '7': 11, '8': 12, '9': 13, 'A': 14, 'B': 15, 'C': 16, 'D': 17, 'E': 18, 'F': 19, 'G': 20, 'H': 21, 'I': 22, 'J': 23, 'K': 24, 'L': 25, 'M': 26, 'N': 27, 'O': 28, 'P': 29, 'Q': 30, 'R': 31, 'S': 32, 'T': 33, 'U': 34, 'V': 35, 'W': 36, 'X': 37, 'Y': 38, 'Z': 39, 'a': 40, 'b': 41, 'c': 42, 'd': 43, 'e': 44, 'f': 45, 'g': 46, 'h': 47, 'i': 48, 'j': 49, 'k': 50, 'l': 51, 'm': 52, 'n': 53, 'o': 54, 'p': 55, 'q': 56, 'r': 57, 's': 58, 't': 59, 'u': 60, 'v': 61, 'w': 62, 'x': 63, 'y': 64, 'z': 65, ' ': 66, '!': 67, '\"': 68, '#': 69, '&': 70, \"'\": 71, '(': 72, ')': 73, '*': 74, '+': 75, ',': 76, '-': 77, '.': 78, '/': 79, ':': 80, ';': 81, '?': 82, '\\n': 83}, input_size=[28, 28], data_dir=PosixPath('/home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb'), num_features=1000, tokens='iamdb_1kwp_tokens_1000.txt', lexicon='iamdb_1kwp_lex_1000.txt', use_words=False, prepend_wordsep=False, special_tokens={'

', '', ''}, wordpiece_processor=)" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "tt" - ] - }, { "cell_type": "code", "execution_count": null, - "id": "2452e8f4-cc5f-4763-9a25-4fa27b7f143e", + "id": "9c797159-845e-42c6-bd65-1c976ad627cd", "metadata": {}, "outputs": [], "source": [ - "tt.mapping" + "# context initialization\n", + "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", + " cfg = compose(config_name=\"conv_transformer\")\n", + " print(OmegaConf.to_yaml(cfg))\n", + " print(cfg)" ] }, { "cell_type": "code", "execution_count": null, - "id": "6b722ca0-9c65-4f90-be4e-b7334ea81237", + "id": "af2c8cfa-0b45-4681-b671-0f97ace62516", "metadata": {}, "outputs": [], "source": [ - "# context initialization\n", - "with initialize(config_path=\"../training/conf/model/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"lit_transformer\")\n", - " print(OmegaConf.to_yaml(cfg))\n", - " print(cfg)" + "net = instantiate(cfg)" ] }, { "cell_type": "code", "execution_count": null, - "id": "9c797159-845e-42c6-bd65-1c976ad627cd", - "metadata": {}, + "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f", + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "# context initialization\n", - "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"conv_transformer\")\n", - " print(OmegaConf.to_yaml(cfg))\n", - " print(cfg)" + "net" ] }, { "cell_type": "code", "execution_count": null, - "id": "dcfbe2ab-6775-4aa4-acf4-57203a3f5511", + "id": "40be59bc-db79-4af1-9df4-e280f7a56481", "metadata": {}, "outputs": [], "source": [ - "from importlib import import_module" + "img = torch.rand(4, 1, 576, 640)" ] }, { "cell_type": "code", "execution_count": null, - "id": "e3d4c70e-509d-457a-ac81-2bac27cb95d2", + "id": "d5a8f10b-edf5-4a18-9747-f016db72c384", "metadata": {}, "outputs": [], "source": [ - "x = import_module(\"text_recognizer.networks.transformer.attention\")" + "y = torch.randint(0, 1006, (4, 451))" ] }, { "cell_type": "code", "execution_count": null, - "id": "039d4a7f-f70d-43a1-8b5f-7e766ac01010", + "id": "19423ef1-3d98-4af3-8748-fdd3bb817300", "metadata": {}, "outputs": [], "source": [ - "y = partial(getattr(x, \"Attention\"), dim=16, num_heads=2, **cfg.decoder.attn_kwargs)" + "y.shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "9be1d661-bfac-4826-ab8d-453557713f68", + "id": "0712ee7e-4f66-4fb1-bc91-d8a127eb7ac7", "metadata": {}, "outputs": [], "source": [ - "y().causal" + "net = net.cuda()\n", + "img = img.cuda()\n", + "y = y.cuda()" ] }, { "cell_type": "code", "execution_count": null, - "id": "54b35e6f-35db-4769-8bc5-ed1764768cf2", + "id": "719154b4-47db-4c91-bae4-8c572c4a4536", "metadata": {}, "outputs": [], "source": [ - "y(causal=True)" + "net(img, y).shape" ] }, { "cell_type": "code", "execution_count": null, - "id": "af2c8cfa-0b45-4681-b671-0f97ace62516", + "id": "bcb7db0f-0afe-44eb-9bb7-b988fbead95a", "metadata": {}, "outputs": [], "source": [ - "net = instantiate(cfg)" + "from torchsummary import summary" ] }, { "cell_type": "code", "execution_count": null, - "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f", + "id": "31af8ee1-28d3-46b8-a847-6506d29bc45c", "metadata": {}, "outputs": [], "source": [ - "net" + "summary(net, [(1, 576, 640), (451,)], device=\"cpu\", depth=2)" ] }, { "cell_type": "code", "execution_count": null, - "id": "709be6cc-6708-4561-ad45-28f433612a0d", + "id": "4d6d836f-d169-48b4-92e6-ca17179e6f85", "metadata": {}, "outputs": [], "source": [] diff --git a/text_recognizer/data/base_data_module.py b/text_recognizer/data/base_data_module.py index 408ae36..fd914b6 100644 --- a/text_recognizer/data/base_data_module.py +++ b/text_recognizer/data/base_data_module.py @@ -1,11 +1,12 @@ """Base lightning DataModule class.""" from pathlib import Path -from typing import Any, Dict, Tuple +from typing import Dict, Tuple import attr from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.data.base_dataset import BaseDataset @@ -24,8 +25,10 @@ class BaseDataModule(LightningDataModule): def __attrs_pre_init__(self) -> None: super().__init__() + mapping: AbstractMapping = attr.ib() batch_size: int = attr.ib(default=16) num_workers: int = attr.ib(default=0) + pin_memory: bool = attr.ib(default=True) # Placeholders data_train: BaseDataset = attr.ib(init=False, default=None) @@ -33,8 +36,6 @@ class BaseDataModule(LightningDataModule): data_test: BaseDataset = attr.ib(init=False, default=None) dims: Tuple[int, ...] = attr.ib(init=False, default=None) output_dims: Tuple[int, ...] = attr.ib(init=False, default=None) - mapping: Any = attr.ib(init=False, default=None) - inverse_mapping: Dict[str, int] = attr.ib(init=False) @classmethod def data_dirname(cls) -> Path: @@ -46,7 +47,6 @@ class BaseDataModule(LightningDataModule): return { "input_dim": self.dims, "output_dims": self.output_dims, - "mapping": self.mapping, } def prepare_data(self) -> None: @@ -72,7 +72,7 @@ class BaseDataModule(LightningDataModule): shuffle=True, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) def val_dataloader(self) -> DataLoader: @@ -82,7 +82,7 @@ class BaseDataModule(LightningDataModule): shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) def test_dataloader(self) -> DataLoader: @@ -92,5 +92,5 @@ class BaseDataModule(LightningDataModule): shuffle=False, batch_size=self.batch_size, num_workers=self.num_workers, - pin_memory=True, + pin_memory=self.pin_memory, ) diff --git a/text_recognizer/data/base_dataset.py b/text_recognizer/data/base_dataset.py index c26f1c9..8640d92 100644 --- a/text_recognizer/data/base_dataset.py +++ b/text_recognizer/data/base_dataset.py @@ -1,5 +1,5 @@ """Base PyTorch Dataset class.""" -from typing import Any, Callable, Dict, Sequence, Tuple, Union +from typing import Callable, Dict, Optional, Sequence, Tuple, Union import attr import torch @@ -22,14 +22,13 @@ class BaseDataset(Dataset): data: Union[Sequence, Tensor] = attr.ib() targets: Union[Sequence, Tensor] = attr.ib() - transform: Callable = attr.ib() - target_transform: Callable = attr.ib() + transform: Optional[Callable] = attr.ib(default=None) + target_transform: Optional[Callable] = attr.ib(default=None) def __attrs_pre_init__(self) -> None: super().__init__() def __attrs_post_init__(self) -> None: - # TODO: refactor this if len(self.data) != len(self.targets): raise ValueError("Data and targets must be of equal length.") @@ -37,14 +36,14 @@ class BaseDataset(Dataset): """Return the length of the dataset.""" return len(self.data) - def __getitem__(self, index: int) -> Tuple[Any, Any]: + def __getitem__(self, index: int) -> Tuple[Tensor, Tensor]: """Return a datum and its target, after processing by transforms. Args: index (int): Index of a datum in the dataset. Returns: - Tuple[Any, Any]: Datum and target pair. + Tuple[Tensor, Tensor]: Datum and target pair. """ datum, target = self.data[index], self.targets[index] diff --git a/text_recognizer/data/download_utils.py b/text_recognizer/data/download_utils.py index e3dc68c..8938830 100644 --- a/text_recognizer/data/download_utils.py +++ b/text_recognizer/data/download_utils.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional from urllib.request import urlretrieve -from loguru import logger +from loguru import logger as log from tqdm import tqdm @@ -32,7 +32,7 @@ class TqdmUpTo(tqdm): total_size (Optional[int]): Total size in tqdm units. Defaults to None. """ if total_size is not None: - self.total = total_size # pylint: disable=attribute-defined-outside-init + self.total = total_size self.update(blocks * block_size - self.n) @@ -62,9 +62,9 @@ def download_dataset(metadata: Dict, dl_dir: Path) -> Optional[Path]: filename = dl_dir / metadata["filename"] if filename.exists(): return - logger.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") + log.info(f"Downloading raw dataset from {metadata['url']} to {filename}...") _download_url(metadata["url"], filename) - logger.info("Computing the SHA-256...") + log.info("Computing the SHA-256...") sha256 = _compute_sha256(filename) if sha256 != metadata["sha256"]: raise ValueError( diff --git a/text_recognizer/data/emnist.py b/text_recognizer/data/emnist.py index 2d0ac29..c6be123 100644 --- a/text_recognizer/data/emnist.py +++ b/text_recognizer/data/emnist.py @@ -3,12 +3,12 @@ import json import os from pathlib import Path import shutil -from typing import Callable, Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Set, Sequence, Tuple import zipfile import attr import h5py -from loguru import logger +from loguru import logger as log import numpy as np import toml import torchvision.transforms as T @@ -50,8 +50,7 @@ class EMNIST(BaseDataModule): transform: Callable = attr.ib(init=False, default=T.Compose([T.ToTensor()])) def __attrs_post_init__(self) -> None: - self.mapping, self.inverse_mapping, input_shape = emnist_mapping() - self.dims = (1, *input_shape) + self.dims = (1, *self.mapping.input_size) def prepare_data(self) -> None: """Downloads dataset if not present.""" @@ -106,7 +105,7 @@ class EMNIST(BaseDataModule): def emnist_mapping( - extra_symbols: Optional[Sequence[str]] = None, + extra_symbols: Optional[Set[str]] = None, ) -> Tuple[List, Dict[str, int], List[int]]: """Return the EMNIST mapping.""" if not ESSENTIALS_FILENAME.exists(): @@ -130,7 +129,7 @@ def download_and_process_emnist() -> None: def _process_raw_dataset(filename: str, dirname: Path) -> None: """Processes the raw EMNIST dataset.""" - logger.info("Unzipping EMNIST...") + log.info("Unzipping EMNIST...") curdir = os.getcwd() os.chdir(dirname) content = zipfile.ZipFile(filename, "r") @@ -138,7 +137,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: from scipy.io import loadmat - logger.info("Loading training data from .mat file") + log.info("Loading training data from .mat file") data = loadmat("matlab/emnist-byclass.mat") x_train = ( data["dataset"]["train"][0, 0]["images"][0, 0] @@ -152,11 +151,11 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: y_test = data["dataset"]["test"][0, 0]["labels"][0, 0] + NUM_SPECIAL_TOKENS if SAMPLE_TO_BALANCE: - logger.info("Balancing classes to reduce amount of data") + log.info("Balancing classes to reduce amount of data") x_train, y_train = _sample_to_balance(x_train, y_train) x_test, y_test = _sample_to_balance(x_test, y_test) - logger.info("Saving to HDF5 in a compressed format...") + log.info("Saving to HDF5 in a compressed format...") PROCESSED_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) with h5py.File(PROCESSED_DATA_FILENAME, "w") as f: f.create_dataset("x_train", data=x_train, dtype="u1", compression="lzf") @@ -164,7 +163,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: f.create_dataset("x_test", data=x_test, dtype="u1", compression="lzf") f.create_dataset("y_test", data=y_test, dtype="u1", compression="lzf") - logger.info("Saving essential dataset parameters to text_recognizer/datasets...") + log.info("Saving essential dataset parameters to text_recognizer/datasets...") mapping = {int(k): chr(v) for k, v in data["dataset"]["mapping"][0, 0]} characters = _augment_emnist_characters(mapping.values()) essentials = {"characters": characters, "input_shape": list(x_train.shape[1:])} @@ -172,7 +171,7 @@ def _process_raw_dataset(filename: str, dirname: Path) -> None: with ESSENTIALS_FILENAME.open(mode="w") as f: json.dump(essentials, f) - logger.info("Cleaning up...") + log.info("Cleaning up...") shutil.rmtree("matlab") os.chdir(curdir) diff --git a/text_recognizer/data/emnist_lines.py b/text_recognizer/data/emnist_lines.py index 7548ad5..5298726 100644 --- a/text_recognizer/data/emnist_lines.py +++ b/text_recognizer/data/emnist_lines.py @@ -1,11 +1,11 @@ """Dataset of generated text from EMNIST characters.""" from collections import defaultdict from pathlib import Path -from typing import Callable, Dict, Tuple +from typing import Callable, List, Tuple import attr import h5py -from loguru import logger +from loguru import logger as log import numpy as np import torch from torchvision import transforms @@ -46,8 +46,7 @@ class EMNISTLines(BaseDataModule): emnist: EMNIST = attr.ib(init=False, default=None) def __attrs_post_init__(self) -> None: - self.emnist = EMNIST() - self.mapping = self.emnist.mapping + self.emnist = EMNIST(mapping=self.mapping) max_width = ( int(self.emnist.dims[2] * (self.max_length + 1) * (1 - self.min_overlap)) @@ -86,7 +85,7 @@ class EMNISTLines(BaseDataModule): self._generate_data("test") def setup(self, stage: str = None) -> None: - logger.info("EMNISTLinesDataset loading data from HDF5...") + log.info("EMNISTLinesDataset loading data from HDF5...") if stage == "fit" or stage is None: print(self.data_filename) with h5py.File(self.data_filename, "r") as f: @@ -137,7 +136,7 @@ class EMNISTLines(BaseDataModule): return basic + data def _generate_data(self, split: str) -> None: - logger.info(f"EMNISTLines generating data for {split}...") + log.info(f"EMNISTLines generating data for {split}...") sentence_generator = SentenceGenerator( self.max_length - 2 ) # Subtract by 2 because start/end token @@ -148,17 +147,17 @@ class EMNISTLines(BaseDataModule): if split == "train": samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping + emnist.x_train, emnist.y_train, self.mapping.mapping ) num = self.num_train elif split == "val": samples_by_char = _get_samples_by_char( - emnist.x_train, emnist.y_train, emnist.mapping + emnist.x_train, emnist.y_train, self.mapping.mapping ) num = self.num_val else: samples_by_char = _get_samples_by_char( - emnist.x_test, emnist.y_test, emnist.mapping + emnist.x_test, emnist.y_test, self.mapping.mapping ) num = self.num_test @@ -173,14 +172,14 @@ class EMNISTLines(BaseDataModule): self.dims, ) y = convert_strings_to_labels( - y, emnist.inverse_mapping, length=MAX_OUTPUT_LENGTH + y, self.mapping.inverse_mapping, length=MAX_OUTPUT_LENGTH ) f.create_dataset(f"x_{split}", data=x, dtype="u1", compression="lzf") f.create_dataset(f"y_{split}", data=y, dtype="u1", compression="lzf") def _get_samples_by_char( - samples: np.ndarray, labels: np.ndarray, mapping: Dict + samples: np.ndarray, labels: np.ndarray, mapping: List ) -> defaultdict: samples_by_char = defaultdict(list) for sample, label in zip(samples, labels): diff --git a/text_recognizer/data/iam.py b/text_recognizer/data/iam.py index 3982c4f..7278eb2 100644 --- a/text_recognizer/data/iam.py +++ b/text_recognizer/data/iam.py @@ -7,7 +7,7 @@ import zipfile import attr from boltons.cacheutils import cachedproperty -from loguru import logger +from loguru import logger as log import toml from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info @@ -92,7 +92,7 @@ class IAM(BaseDataModule): def _extract_raw_dataset(filename: Path, dirname: Path) -> None: - logger.info("Extracting IAM data...") + log.info("Extracting IAM data...") curdir = os.getcwd() os.chdir(dirname) with zipfile.ZipFile(filename, "r") as f: diff --git a/text_recognizer/data/iam_extended_paragraphs.py b/text_recognizer/data/iam_extended_paragraphs.py index 0e97801..ccf0759 100644 --- a/text_recognizer/data/iam_extended_paragraphs.py +++ b/text_recognizer/data/iam_extended_paragraphs.py @@ -4,7 +4,6 @@ from typing import Dict, List import attr from torch.utils.data import ConcatDataset -from text_recognizer.data.base_dataset import BaseDataset from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info from text_recognizer.data.iam_paragraphs import IAMParagraphs from text_recognizer.data.iam_synthetic_paragraphs import IAMSyntheticParagraphs @@ -20,6 +19,7 @@ class IAMExtendedParagraphs(BaseDataModule): def __attrs_post_init__(self) -> None: self.iam_paragraphs = IAMParagraphs( + mapping=self.mapping, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, @@ -27,6 +27,7 @@ class IAMExtendedParagraphs(BaseDataModule): word_pieces=self.word_pieces, ) self.iam_synthetic_paragraphs = IAMSyntheticParagraphs( + mapping=self.mapping, batch_size=self.batch_size, num_workers=self.num_workers, train_fraction=self.train_fraction, @@ -36,7 +37,6 @@ class IAMExtendedParagraphs(BaseDataModule): self.dims = self.iam_paragraphs.dims self.output_dims = self.iam_paragraphs.output_dims - self.num_classes = self.iam_paragraphs.num_classes def prepare_data(self) -> None: """Prepares the paragraphs data.""" @@ -58,7 +58,7 @@ class IAMExtendedParagraphs(BaseDataModule): """Returns info about the dataset.""" basic = ( "IAM Original and Synthetic Paragraphs Dataset\n" # pylint: disable=no-member - f"Num classes: {len(self.num_classes)}\n" + f"Num classes: {len(self.mapping)}\n" f"Dims: {self.dims}\n" f"Output dims: {self.output_dims}\n" ) diff --git a/text_recognizer/data/iam_lines.py b/text_recognizer/data/iam_lines.py index b7f3fdd..1c63729 100644 --- a/text_recognizer/data/iam_lines.py +++ b/text_recognizer/data/iam_lines.py @@ -2,15 +2,14 @@ If not created, will generate a handwritten lines dataset from the IAM paragraphs dataset. - """ import json from pathlib import Path import random -from typing import Dict, List, Sequence, Tuple +from typing import List, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log from PIL import Image, ImageFile, ImageOps import numpy as np from torch import Tensor @@ -23,7 +22,7 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data import image_utils @@ -48,17 +47,13 @@ class IAMLines(BaseDataModule): ) output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) - def __attrs_post_init__(self) -> None: - # TODO: refactor this - self.mapping, self.inverse_mapping, _ = emnist_mapping() - def prepare_data(self) -> None: """Creates the IAM lines dataset if not existing.""" if PROCESSED_DATA_DIRNAME.exists(): return - logger.info("Cropping IAM lines regions...") - iam = IAM() + log.info("Cropping IAM lines regions...") + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") crops_test, labels_test = line_crops_and_labels(iam, "test") @@ -66,7 +61,7 @@ class IAMLines(BaseDataModule): shapes = np.array([crop.size for crop in crops_train + crops_test]) aspect_ratios = shapes[:, 0] / shapes[:, 1] - logger.info("Saving images, labels, and statistics...") + log.info("Saving images, labels, and statistics...") save_images_and_labels( crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME ) @@ -91,7 +86,7 @@ class IAMLines(BaseDataModule): raise ValueError("Target length longer than max output length.") y_train = convert_strings_to_labels( - labels_train, self.inverse_mapping, length=self.output_dims[0] + labels_train, self.mapping.inverse_mapping, length=self.output_dims[0] ) data_train = BaseDataset( x_train, y_train, transform=get_transform(IMAGE_WIDTH, self.augment) @@ -110,7 +105,7 @@ class IAMLines(BaseDataModule): raise ValueError("Taget length longer than max output length.") y_test = convert_strings_to_labels( - labels_test, self.inverse_mapping, length=self.output_dims[0] + labels_test, self.mapping.inverse_mapping, length=self.output_dims[0] ) self.data_test = BaseDataset( x_test, y_test, transform=get_transform(IMAGE_WIDTH) diff --git a/text_recognizer/data/iam_paragraphs.py b/text_recognizer/data/iam_paragraphs.py index 0f3a2ce..6189f7d 100644 --- a/text_recognizer/data/iam_paragraphs.py +++ b/text_recognizer/data/iam_paragraphs.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import Dict, List, Optional, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log import numpy as np from PIL import Image, ImageOps import torchvision.transforms as T @@ -17,9 +17,8 @@ from text_recognizer.data.base_dataset import ( split_dataset, ) from text_recognizer.data.base_data_module import BaseDataModule, load_and_print_info -from text_recognizer.data.emnist import emnist_mapping +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM -from text_recognizer.data.mappings import WordPieceMapping from text_recognizer.data.transforms import WordPiece @@ -38,7 +37,6 @@ MAX_LABEL_LENGTH = 682 class IAMParagraphs(BaseDataModule): """IAM handwriting database paragraphs.""" - num_classes: int = attr.ib() word_pieces: bool = attr.ib(default=False) augment: bool = attr.ib(default=True) train_fraction: float = attr.ib(default=0.8) @@ -46,21 +44,17 @@ class IAMParagraphs(BaseDataModule): init=False, default=(1, IMAGE_HEIGHT, IMAGE_WIDTH) ) output_dims: Tuple[int, int] = attr.ib(init=False, default=(MAX_LABEL_LENGTH, 1)) - inverse_mapping: Dict[str, int] = attr.ib(init=False) - - def __attrs_post_init__(self) -> None: - _, self.inverse_mapping, _ = emnist_mapping(extra_symbols=[NEW_LINE_TOKEN]) def prepare_data(self) -> None: """Create data for training/testing.""" if PROCESSED_DATA_DIRNAME.exists(): return - logger.info( + log.info( "Cropping IAM paragraph regions and saving them along with labels..." ) - iam = IAM() + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() properties = {} @@ -89,7 +83,7 @@ class IAMParagraphs(BaseDataModule): crops, labels = _load_processed_crops_and_labels(split) data = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops] targets = convert_strings_to_labels( - strings=labels, mapping=self.inverse_mapping, length=self.output_dims[0] + strings=labels, mapping=self.mapping.inverse_mapping, length=self.output_dims[0] ) return BaseDataset( data, @@ -98,7 +92,7 @@ class IAMParagraphs(BaseDataModule): target_transform=get_target_transform(self.word_pieces), ) - logger.info(f"Loading IAM paragraph regions and lines for {stage}...") + log.info(f"Loading IAM paragraph regions and lines for {stage}...") _validate_data_dims(input_dims=self.dims, output_dims=self.output_dims) if stage == "fit" or stage is None: diff --git a/text_recognizer/data/iam_preprocessor.py b/text_recognizer/data/iam_preprocessor.py index 93a13bb..bcd77b4 100644 --- a/text_recognizer/data/iam_preprocessor.py +++ b/text_recognizer/data/iam_preprocessor.py @@ -1,18 +1,16 @@ """Preprocessor for extracting word letters from the IAM dataset. The code is mostly stolen from: - https://github.com/facebookresearch/gtn_applications/blob/master/datasets/iamdb.py - """ import collections import itertools from pathlib import Path import re -from typing import List, Optional, Union, Sequence +from typing import List, Optional, Union, Set import click -from loguru import logger +from loguru import logger as log import torch @@ -57,7 +55,7 @@ class Preprocessor: lexicon_path: Optional[Union[str, Path]] = None, use_words: bool = False, prepend_wordsep: bool = False, - special_tokens: Optional[Sequence[str]] = None, + special_tokens: Optional[Set[str]] = None, ) -> None: self.wordsep = "▁" self._use_word = use_words @@ -186,7 +184,7 @@ def cli( / "iam" / "iamdb" ) - logger.debug(f"Using data dir: {data_dir}") + log.debug(f"Using data dir: {data_dir}") if not data_dir.exists(): raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") else: @@ -196,15 +194,15 @@ def cli( preprocessor.extract_train_text() processed_dir = data_dir.parents[2] / "processed" / "iam_lines" - logger.debug(f"Saving processed files at: {processed_dir}") + log.debug(f"Saving processed files at: {processed_dir}") if save_text is not None: - logger.info("Saving training text") + log.info("Saving training text") with open(processed_dir / save_text, "w") as f: f.write("\n".join(t for t in preprocessor.text)) if save_tokens is not None: - logger.info("Saving tokens") + log.info("Saving tokens") with open(processed_dir / save_tokens, "w") as f: f.write("\n".join(preprocessor.tokens)) diff --git a/text_recognizer/data/iam_synthetic_paragraphs.py b/text_recognizer/data/iam_synthetic_paragraphs.py index f00a494..c938f8b 100644 --- a/text_recognizer/data/iam_synthetic_paragraphs.py +++ b/text_recognizer/data/iam_synthetic_paragraphs.py @@ -3,7 +3,7 @@ import random from typing import Any, List, Sequence, Tuple import attr -from loguru import logger +from loguru import logger as log import numpy as np from PIL import Image @@ -21,6 +21,7 @@ from text_recognizer.data.iam_paragraphs import ( IMAGE_SCALE_FACTOR, resize_image, ) +from text_recognizer.data.mappings import EmnistMapping from text_recognizer.data.iam import IAM from text_recognizer.data.iam_lines import ( line_crops_and_labels, @@ -43,10 +44,10 @@ class IAMSyntheticParagraphs(IAMParagraphs): if PROCESSED_DATA_DIRNAME.exists(): return - logger.info("Preparing IAM lines for synthetic paragraphs dataset.") - logger.info("Cropping IAM line regions and loading labels.") + log.info("Preparing IAM lines for synthetic paragraphs dataset.") + log.info("Cropping IAM line regions and loading labels.") - iam = IAM() + iam = IAM(mapping=EmnistMapping()) iam.prepare_data() crops_train, labels_train = line_crops_and_labels(iam, "train") @@ -55,7 +56,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): crops_train = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_train] crops_test = [resize_image(crop, IMAGE_SCALE_FACTOR) for crop in crops_test] - logger.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}") + log.info(f"Saving images and labels at {PROCESSED_DATA_DIRNAME}") save_images_and_labels( crops_train, labels_train, "train", PROCESSED_DATA_DIRNAME ) @@ -64,7 +65,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): def setup(self, stage: str = None) -> None: """Loading synthetic dataset.""" - logger.info(f"IAM Synthetic dataset steup for stage {stage}...") + log.info(f"IAM Synthetic dataset steup for stage {stage}...") if stage == "fit" or stage is None: line_crops, line_labels = load_line_crops_and_labels( @@ -76,7 +77,7 @@ class IAMSyntheticParagraphs(IAMParagraphs): targets = convert_strings_to_labels( strings=paragraphs_labels, - mapping=self.inverse_mapping, + mapping=self.mapping.inverse_mapping, length=self.output_dims[0], ) self.data_train = BaseDataset( @@ -144,7 +145,7 @@ def generate_synthetic_paragraphs( [line_labels[i] for i in paragraph_indices] ) if len(paragraph_label) > paragraphs_properties["label_length"]["max"]: - logger.info( + log.info( "Label longer than longest label in original IAM paragraph dataset - hence dropping." ) continue @@ -158,7 +159,7 @@ def generate_synthetic_paragraphs( paragraph_crop.height > max_paragraph_shape[0] or paragraph_crop.width > max_paragraph_shape[1] ): - logger.info( + log.info( "Crop larger than largest crop in original IAM paragraphs dataset - hence dropping" ) continue diff --git a/text_recognizer/data/make_wordpieces.py b/text_recognizer/data/make_wordpieces.py index ef9eb1b..40fbee4 100644 --- a/text_recognizer/data/make_wordpieces.py +++ b/text_recognizer/data/make_wordpieces.py @@ -10,7 +10,7 @@ from pathlib import Path from typing import List, Optional, Union import click -from loguru import logger +from loguru import logger as log import sentencepiece as spm from text_recognizer.data.iam_preprocessor import load_metadata @@ -63,9 +63,9 @@ def save_pieces( vocab: set, ) -> None: """Saves word pieces to disk.""" - logger.info(f"Generating word piece list of size {num_pieces}.") + log.info(f"Generating word piece list of size {num_pieces}.") pieces = [sp.id_to_piece(i) for i in range(1, num_pieces + 1)] - logger.info(f"Encoding vocabulary of size {len(vocab)}.") + log.info(f"Encoding vocabulary of size {len(vocab)}.") encoded_vocab = [sp.encode_as_pieces(v) for v in vocab] # Save pieces to file. @@ -101,7 +101,7 @@ def cli( data_dir = ( Path(__file__).resolve().parents[2] / "data" / "processed" / "iam_lines" ) - logger.debug(f"Using data dir: {data_dir}") + log.debug(f"Using data dir: {data_dir}") if not data_dir.exists(): raise RuntimeError(f"Could not locate iamdb directory at {data_dir}") else: diff --git a/text_recognizer/data/mappings.py b/text_recognizer/data/mappings.py index b69e888..d1c64dd 100644 --- a/text_recognizer/data/mappings.py +++ b/text_recognizer/data/mappings.py @@ -1,18 +1,30 @@ """Mapping to and from word pieces.""" from abc import ABC, abstractmethod from pathlib import Path -from typing import Dict, List, Optional, Union, Set, Sequence +from typing import Dict, List, Optional, Union, Set import attr -import loguru.logger as log import torch +from loguru import logger as log from torch import Tensor from text_recognizer.data.emnist import emnist_mapping from text_recognizer.data.iam_preprocessor import Preprocessor +@attr.s class AbstractMapping(ABC): + input_size: List[int] = attr.ib(init=False) + mapping: List[str] = attr.ib(init=False) + inverse_mapping: Dict[str, int] = attr.ib(init=False) + + def __len__(self) -> int: + return len(self.mapping) + + @property + def num_classes(self) -> int: + return self.__len__() + @abstractmethod def get_token(self, *args, **kwargs) -> str: ... @@ -30,15 +42,13 @@ class AbstractMapping(ABC): ... -@attr.s +@attr.s(auto_attribs=True) class EmnistMapping(AbstractMapping): - extra_symbols: Optional[Set[str]] = attr.ib(default=None, converter=set) - mapping: Sequence[str] = attr.ib(init=False) - inverse_mapping: Dict[str, int] = attr.ib(init=False) - input_size: List[int] = attr.ib(init=False) + extra_symbols: Optional[Set[str]] = attr.ib(default=None) def __attrs_post_init__(self) -> None: """Post init configuration.""" + self.extra_symbols = set(self.extra_symbols) if self.extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( self.extra_symbols ) diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index caf63c1..8ce5c37 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -12,7 +12,7 @@ from torch import Tensor import torchmetrics -@attr.s +@attr.s(eq=False) class BaseLitModel(LightningModule): """Abstract PyTorch Lightning class.""" diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 0eb42dc..f83c9e4 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -8,11 +8,11 @@ from torch import Tensor from torchmetrics import Metric -@attr.s +@attr.s(eq=False) class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_indices: Set = attr.ib(converter=set) + ignore_indices: Set[Tensor] = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 0e01bb5..91e088d 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,5 +1,5 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Sequence, Tuple, Type +from typing import Tuple, Type, Set import attr import torch @@ -10,20 +10,20 @@ from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - mapping: Type[AbstractMapping] = attr.ib() - start_token: str = attr.ib() - end_token: str = attr.ib() - pad_token: str = attr.ib() + mapping: Type[AbstractMapping] = attr.ib(default=None) + start_token: str = attr.ib(default="") + end_token: str = attr.ib(default="") + pad_token: str = attr.ib(default="

") start_index: Tensor = attr.ib(init=False) end_index: Tensor = attr.ib(init=False) pad_index: Tensor = attr.ib(init=False) - ignore_indices: Sequence[str] = attr.ib(init=False) + ignore_indices: Set[Tensor] = attr.ib(init=False) val_cer: CharacterErrorRate = attr.ib(init=False) test_cer: CharacterErrorRate = attr.ib(init=False) @@ -32,7 +32,7 @@ class TransformerLitModel(BaseLitModel): self.start_index = self.mapping.get_index(self.start_token) self.end_index = self.mapping.get_index(self.end_token) self.pad_index = self.mapping.get_index(self.pad_token) - self.ignore_indices = [self.start_index, self.end_index, self.pad_index] + self.ignore_indices = set([self.start_index, self.end_index, self.pad_index]) self.val_cer = CharacterErrorRate(self.ignore_indices) self.test_cer = CharacterErrorRate(self.ignore_indices) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index e215e14..22da018 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -10,7 +10,7 @@ import wandb from text_recognizer.models.base import BaseLitModel -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" diff --git a/text_recognizer/networks/conv_transformer.py b/text_recognizer/networks/conv_transformer.py index 7371be4..09cc654 100644 --- a/text_recognizer/networks/conv_transformer.py +++ b/text_recognizer/networks/conv_transformer.py @@ -13,7 +13,7 @@ from text_recognizer.networks.transformer.positional_encodings import ( ) -@attr.s +@attr.s(eq=False) class ConvTransformer(nn.Module): """Convolutional encoder and transformer decoder network.""" @@ -121,6 +121,7 @@ class ConvTransformer(nn.Module): Returns: Tensor: Sequence of word piece embeddings. """ + context = context.long() context_mask = context != self.pad_index context = self.token_embedding(context) * math.sqrt(self.hidden_dim) context = self.token_pos_encoder(context) diff --git a/text_recognizer/networks/encoders/efficientnet/efficientnet.py b/text_recognizer/networks/encoders/efficientnet/efficientnet.py index a36150a..b8eb53b 100644 --- a/text_recognizer/networks/encoders/efficientnet/efficientnet.py +++ b/text_recognizer/networks/encoders/efficientnet/efficientnet.py @@ -1,4 +1,4 @@ -"""Efficient net.""" +"""Efficientnet backbone.""" from typing import Tuple import attr @@ -12,8 +12,10 @@ from .utils import ( ) -@attr.s +@attr.s(eq=False) class EfficientNet(nn.Module): + """Efficientnet without classification head.""" + def __attrs_pre_init__(self) -> None: super().__init__() @@ -47,11 +49,13 @@ class EfficientNet(nn.Module): @arch.validator def check_arch(self, attribute: attr._make.Attribute, value: str) -> None: + """Validates the efficientnet architecure.""" if value not in self.archs: raise ValueError(f"{value} not a valid architecure.") self.params = self.archs[value] def _build(self) -> None: + """Builds the efficientnet backbone.""" _block_args = block_args() in_channels = 1 # BW out_channels = round_filters(32, self.params) @@ -73,8 +77,9 @@ class EfficientNet(nn.Module): for args in _block_args: args.in_channels = round_filters(args.in_channels, self.params) args.out_channels = round_filters(args.out_channels, self.params) - args.num_repeats = round_repeats(args.num_repeats, self.params) - for _ in range(args.num_repeats): + num_repeats = round_repeats(args.num_repeats, self.params) + del args.num_repeats + for _ in range(num_repeats): self._blocks.append( MBConvBlock( **args, bn_momentum=self.bn_momentum, bn_eps=self.bn_eps, @@ -93,6 +98,7 @@ class EfficientNet(nn.Module): ) def extract_features(self, x: Tensor) -> Tensor: + """Extracts the final feature map layer.""" x = self._conv_stem(x) for i, block in enumerate(self._blocks): stochastic_dropout_rate = self.stochastic_dropout_rate @@ -103,4 +109,5 @@ class EfficientNet(nn.Module): return x def forward(self, x: Tensor) -> Tensor: + """Returns efficientnet image features.""" return self.extract_features(x) diff --git a/text_recognizer/networks/encoders/efficientnet/mbconv.py b/text_recognizer/networks/encoders/efficientnet/mbconv.py index 3aa63d0..e85df87 100644 --- a/text_recognizer/networks/encoders/efficientnet/mbconv.py +++ b/text_recognizer/networks/encoders/efficientnet/mbconv.py @@ -1,76 +1,62 @@ """Mobile inverted residual block.""" -from typing import Any, Optional, Union, Tuple +from typing import Optional, Sequence, Union, Tuple +import attr import torch from torch import nn, Tensor import torch.nn.functional as F -from .utils import stochastic_depth +from text_recognizer.networks.encoders.efficientnet.utils import stochastic_depth +def _convert_stride(stride: Union[Tuple[int, int], int]) -> Tuple[int, int]: + """Converts int to tuple.""" + return ( + (stride,) * 2 if isinstance(stride, int) else stride + ) + + +@attr.s(eq=False) class MBConvBlock(nn.Module): """Mobile Inverted Residual Bottleneck block.""" - def __init__( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - bn_momentum: float, - bn_eps: float, - se_ratio: float, - expand_ratio: int, - *args: Any, - **kwargs: Any, - ) -> None: + def __attrs_pre_init__(self) -> None: super().__init__() - self.kernel_size = kernel_size - self.stride = (stride,) * 2 if isinstance(stride, int) else stride - self.bn_momentum = bn_momentum - self.bn_eps = bn_eps - self.in_channels = in_channels - self.out_channels = out_channels + in_channels: int = attr.ib() + out_channels: int = attr.ib() + kernel_size: Tuple[int, int] = attr.ib() + stride: Tuple[int, int] = attr.ib(converter=_convert_stride) + bn_momentum: float = attr.ib() + bn_eps: float = attr.ib() + se_ratio: float = attr.ib() + expand_ratio: int = attr.ib() + pad: Tuple[int, int, int, int] = attr.ib(init=False) + _inverted_bottleneck: nn.Sequential = attr.ib(init=False) + _depthwise: nn.Sequential = attr.ib(init=False) + _squeeze_excite: nn.Sequential = attr.ib(init=False) + _pointwise: nn.Sequential = attr.ib(init=False) + + @pad.default + def _configure_padding(self) -> Tuple[int, int, int, int]: + """Set padding for convolutional layers.""" if self.stride == (2, 2): - self.pad = [ + return ( (self.kernel_size - 1) // 2 - 1, (self.kernel_size - 1) // 2, - ] * 2 - else: - self.pad = [(self.kernel_size - 1) // 2] * 4 - - # Placeholders for layers. - self._inverted_bottleneck: nn.Sequential = None - self._depthwise: nn.Sequential = None - self._squeeze_excite: nn.Sequential = None - self._pointwise: nn.Sequential = None - - self._build( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, - expand_ratio=expand_ratio, - se_ratio=se_ratio, - ) + ) * 2 + return ((self.kernel_size - 1) // 2,) * 4 + + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self._build() - def _build( - self, - in_channels: int, - out_channels: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], - expand_ratio: int, - se_ratio: float, - ) -> None: - has_se = se_ratio is not None and 0.0 < se_ratio < 1.0 - inner_channels = in_channels * expand_ratio + def _build(self) -> None: + has_se = self.se_ratio is not None and 0.0 < self.se_ratio < 1.0 + inner_channels = self.in_channels * self.expand_ratio self._inverted_bottleneck = ( - self._configure_inverted_bottleneck( - in_channels=in_channels, out_channels=inner_channels, - ) - if expand_ratio != 1 + self._configure_inverted_bottleneck(out_channels=inner_channels) + if self.expand_ratio != 1 else None ) @@ -78,31 +64,23 @@ class MBConvBlock(nn.Module): in_channels=inner_channels, out_channels=inner_channels, groups=inner_channels, - kernel_size=kernel_size, - stride=stride, ) self._squeeze_excite = ( self._configure_squeeze_excite( - in_channels=inner_channels, - out_channels=inner_channels, - se_ratio=se_ratio, + in_channels=inner_channels, out_channels=inner_channels, ) if has_se else None ) - self._pointwise = self._configure_pointwise( - in_channels=inner_channels, out_channels=out_channels - ) + self._pointwise = self._configure_pointwise(in_channels=inner_channels) - def _configure_inverted_bottleneck( - self, in_channels: int, out_channels: int, - ) -> nn.Sequential: + def _configure_inverted_bottleneck(self, out_channels: int) -> nn.Sequential: """Expansion phase.""" return nn.Sequential( nn.Conv2d( - in_channels=in_channels, + in_channels=self.in_channels, out_channels=out_channels, kernel_size=1, bias=False, @@ -114,19 +92,14 @@ class MBConvBlock(nn.Module): ) def _configure_depthwise( - self, - in_channels: int, - out_channels: int, - groups: int, - kernel_size: int, - stride: Union[Tuple[int, int], int], + self, in_channels: int, out_channels: int, groups: int, ) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, out_channels=out_channels, - kernel_size=kernel_size, - stride=stride, + kernel_size=self.kernel_size, + stride=self.stride, groups=groups, bias=False, ), @@ -137,9 +110,9 @@ class MBConvBlock(nn.Module): ) def _configure_squeeze_excite( - self, in_channels: int, out_channels: int, se_ratio: float + self, in_channels: int, out_channels: int ) -> nn.Sequential: - num_squeezed_channels = max(1, int(in_channels * se_ratio)) + num_squeezed_channels = max(1, int(in_channels * self.se_ratio)) return nn.Sequential( nn.Conv2d( in_channels=in_channels, @@ -154,18 +127,18 @@ class MBConvBlock(nn.Module): ), ) - def _configure_pointwise( - self, in_channels: int, out_channels: int - ) -> nn.Sequential: + def _configure_pointwise(self, in_channels: int) -> nn.Sequential: return nn.Sequential( nn.Conv2d( in_channels=in_channels, - out_channels=out_channels, + out_channels=self.out_channels, kernel_size=1, bias=False, ), nn.BatchNorm2d( - num_features=out_channels, momentum=self.bn_momentum, eps=self.bn_eps + num_features=self.out_channels, + momentum=self.bn_momentum, + eps=self.bn_eps, ), ) @@ -186,8 +159,8 @@ class MBConvBlock(nn.Module): residual = x if self._inverted_bottleneck is not None: x = self._inverted_bottleneck(x) - x = F.pad(x, self.pad) + x = F.pad(x, self.pad) x = self._depthwise(x) if self._squeeze_excite is not None: diff --git a/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index 9202cce..37ce29e 100644 --- a/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py @@ -15,7 +15,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding ) -@attr.s +@attr.s(eq=False) class Attention(nn.Module): """Standard attention.""" @@ -31,7 +31,6 @@ class Attention(nn.Module): dropout: nn.Dropout = attr.ib(init=False) fc: nn.Linear = attr.ib(init=False) qkv_fn: nn.Sequential = attr.ib(init=False) - attn_fn: F.softmax = attr.ib(init=False, default=F.softmax) def __attrs_post_init__(self) -> None: """Post init configuration.""" @@ -80,7 +79,7 @@ class Attention(nn.Module): else k_mask ) q_mask = rearrange(q_mask, "b i -> b () i ()") - k_mask = rearrange(k_mask, "b i -> b () () j") + k_mask = rearrange(k_mask, "b j -> b () () j") return q_mask * k_mask return @@ -129,7 +128,7 @@ class Attention(nn.Module): if self.causal: energy = self._apply_causal_mask(energy, mask, mask_value, device) - attn = self.attn_fn(energy, dim=-1) + attn = F.softmax(energy, dim=-1) attn = self.dropout(attn) out = einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") diff --git a/text_recognizer/networks/transformer/layers.py b/text_recognizer/networks/transformer/layers.py index 66c9c50..ce443e5 100644 --- a/text_recognizer/networks/transformer/layers.py +++ b/text_recognizer/networks/transformer/layers.py @@ -12,7 +12,7 @@ from text_recognizer.networks.transformer.positional_encodings.rotary_embedding from text_recognizer.networks.util import load_partial_fn -@attr.s +@attr.s(eq=False) class AttentionLayers(nn.Module): """Standard transfomer layer.""" @@ -101,11 +101,11 @@ class AttentionLayers(nn.Module): return x -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Encoder(AttentionLayers): causal: bool = attr.ib(default=False, init=False) -@attr.s(auto_attribs=True) +@attr.s(auto_attribs=True, eq=False) class Decoder(AttentionLayers): causal: bool = attr.ib(default=True, init=False) diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml deleted file mode 100644 index 0017e11..0000000 --- a/training/conf/callbacks/wandb.yaml +++ /dev/null @@ -1,20 +0,0 @@ -defaults: - - default.yaml - -watch_model: - _target_: callbacks.wandb_callbacks.WatchModel - log: all - log_freq: 100 - -upload_code_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact - project_dir: ${work_dir}/text_recognizer - -upload_ckpts_as_artifact: - _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact - ckpt_dir: checkpoints/ - upload_best_only: true - -log_text_predictions: - _target_: callbacks.wandb_callbacks.LogTextPredictions - num_samples: 8 diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml new file mode 100644 index 0000000..a4a16ff --- /dev/null +++ b/training/conf/callbacks/wandb/checkpoints.yaml @@ -0,0 +1,4 @@ +upload_ckpts_as_artifact: + _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact + ckpt_dir: checkpoints/ + upload_best_only: true diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb/code.yaml new file mode 100644 index 0000000..35f6ea3 --- /dev/null +++ b/training/conf/callbacks/wandb/code.yaml @@ -0,0 +1,3 @@ +upload_code_as_artifact: + _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact + project_dir: ${work_dir}/text_recognizer diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb/image_reconstructions.yaml new file mode 100644 index 0000000..e69de29 diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb/ocr_predictions.yaml new file mode 100644 index 0000000..573fa96 --- /dev/null +++ b/training/conf/callbacks/wandb/ocr_predictions.yaml @@ -0,0 +1,3 @@ +log_text_predictions: + _target_: callbacks.wandb_callbacks.LogTextPredictions + num_samples: 8 diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml new file mode 100644 index 0000000..511608c --- /dev/null +++ b/training/conf/callbacks/wandb/watch.yaml @@ -0,0 +1,4 @@ +watch_model: + _target_: callbacks.wandb_callbacks.WatchModel + log: all + log_freq: 100 diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml new file mode 100644 index 0000000..efa3dda --- /dev/null +++ b/training/conf/callbacks/wandb_ocr.yaml @@ -0,0 +1,6 @@ +defaults: + - default + - wandb/watch + - wandb/code + - wandb/checkpoints + - wandb/ocr_predictions diff --git a/training/conf/config.yaml b/training/conf/config.yaml index a8e718e..93215ed 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,19 +1,17 @@ defaults: - - network: vqvae - - criterion: mse - - optimizer: madgrad - - lr_scheduler: one_cycle - - model: lit_vqvae + - callbacks: wandb_ocr + - criterion: label_smoothing - dataset: iam_extended_paragraphs + - hydra: default + - lr_scheduler: one_cycle + - mapping: word_piece + - model: lit_transformer + - network: conv_transformer + - optimizer: madgrad - trainer: default - - callbacks: - - checkpoint - - learning_rate_monitor seed: 4711 -wandb: false tune: false train: true test: true -load_checkpoint: null logging: INFO diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml index ee47c59..13daba8 100644 --- a/training/conf/criterion/label_smoothing.yaml +++ b/training/conf/criterion/label_smoothing.yaml @@ -1,4 +1,4 @@ -_target_: text_recognizer.criterion.label_smoothing +_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss label_smoothing: 0.1 vocab_size: 1006 ignore_index: 1002 diff --git a/training/conf/hydra/default.yaml b/training/conf/hydra/default.yaml new file mode 100644 index 0000000..dfd9721 --- /dev/null +++ b/training/conf/hydra/default.yaml @@ -0,0 +1,6 @@ +# output paths for hydra logs +run: + dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S} +sweep: + dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S} + subdir: ${hydra.job.num} diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml new file mode 100644 index 0000000..3792523 --- /dev/null +++ b/training/conf/mapping/word_piece.yaml @@ -0,0 +1,9 @@ +_target_: text_recognizer.data.mappings.WordPieceMapping +num_features: 1000 +tokens: iamdb_1kwp_tokens_1000.txt +lexicon: iamdb_1kwp_lex_1000.txt +data_dir: null +use_words: false +prepend_wordsep: false +special_tokens: [ , ,

] +extra_symbols: [ \n ] diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml index 5341d8e..6ffde4e 100644 --- a/training/conf/model/lit_transformer.yaml +++ b/training/conf/model/lit_transformer.yaml @@ -1,8 +1,5 @@ -defaults: - - mapping: word_piece - _target_: text_recognizer.models.transformer.TransformerLitModel -interval: null +interval: step monitor: val/loss ignore_tokens: [ , ,

] start_token: diff --git a/training/conf/model/mapping/word_piece.yaml b/training/conf/model/mapping/word_piece.yaml deleted file mode 100644 index 39e2ba4..0000000 --- a/training/conf/model/mapping/word_piece.yaml +++ /dev/null @@ -1,9 +0,0 @@ -_target_: text_recognizer.data.mappings.WordPieceMapping -num_features: 1000 -tokens: iamdb_1kwp_tokens_1000.txt -lexicon: iamdb_1kwp_lex_1000.txt -data_dir: null -use_words: false -prepend_wordsep: false -special_tokens: ["", "", "

"] -extra_symbols: ["\n"] diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index 7d57a2d..a97157d 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -4,7 +4,7 @@ defaults: _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] -hidden_dim: 256 +hidden_dim: 96 dropout_rate: 0.2 max_output_len: 451 num_classes: 1006 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index 3122de1..90b9d8a 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -2,12 +2,12 @@ defaults: - rotary_emb: null _target_: text_recognizer.networks.transformer.Decoder -dim: 256 +dim: 96 depth: 2 num_heads: 8 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: - dim_head: 64 + dim_head: 16 dropout_rate: 0.2 norm_fn: torch.nn.LayerNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml index 5ed6552..c665adc 100644 --- a/training/conf/trainer/default.yaml +++ b/training/conf/trainer/default.yaml @@ -6,6 +6,10 @@ gradient_clip_val: 0 fast_dev_run: false gpus: 1 precision: 16 -max_epochs: 64 +max_epochs: 512 terminate_on_nan: true weights_summary: top +limit_train_batches: 1.0 +limit_val_batches: 1.0 +limit_test_batches: 1.0 +resume_from_checkpoint: null diff --git a/training/run.py b/training/run.py index d88a8f6..30479c6 100644 --- a/training/run.py +++ b/training/run.py @@ -2,7 +2,7 @@ from typing import List, Optional, Type import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig from pytorch_lightning import ( Callback, @@ -12,6 +12,7 @@ from pytorch_lightning import ( Trainer, ) from pytorch_lightning.loggers import LightningLoggerBase +from text_recognizer.data.mappings import AbstractMapping from torch import nn import utils @@ -25,15 +26,19 @@ def run(config: DictConfig) -> Optional[float]: if config.get("seed"): seed_everything(config.seed) + log.info(f"Instantiating mapping <{config.mapping._target_}>") + mapping: AbstractMapping = hydra.utils.instantiate(config.mapping) + log.info(f"Instantiating datamodule <{config.datamodule._target_}>") - datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule) + datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping) log.info(f"Instantiating network <{config.network._target_}>") - network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config()) + network: nn.Module = hydra.utils.instantiate(config.network) log.info(f"Instantiating model <{config.model._target_}>") model: LightningModule = hydra.utils.instantiate( **config.model, + mapping=mapping, network=network, criterion_config=config.criterion, optimizer_config=config.optimizer, diff --git a/training/utils.py b/training/utils.py index 564b9bb..ef74f61 100644 --- a/training/utils.py +++ b/training/utils.py @@ -3,7 +3,7 @@ from typing import Any, List, Type import warnings import hydra -import loguru.logger as log +from loguru import logger as log from omegaconf import DictConfig, OmegaConf from pytorch_lightning import ( Callback, -- cgit v1.2.3-70-g09d2