From d3afa310f77f47553586eeee58e3d3345a754e2c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 4 Aug 2021 05:03:51 +0200 Subject: New VQVAE --- notebooks/00-scratch-pad.ipynb | 220 +++++++++++- notebooks/05c-test-model-end-to-end.ipynb | 367 +++++++++------------ text_recognizer/models/vqvae.py | 16 +- text_recognizer/networks/vq_transformer.py | 77 +++++ text_recognizer/networks/vqvae/__init__.py | 3 - text_recognizer/networks/vqvae/decoder.py | 164 +++------ text_recognizer/networks/vqvae/encoder.py | 176 +++------- text_recognizer/networks/vqvae/quantizer.py | 142 ++++++++ text_recognizer/networks/vqvae/residual.py | 18 + text_recognizer/networks/vqvae/vector_quantizer.py | 119 ------- text_recognizer/networks/vqvae/vqvae.py | 122 ++++--- training/callbacks/wandb_callbacks.py | 69 ++-- .../callbacks/wandb_image_reconstructions.yaml | 3 + training/conf/callbacks/wandb_vae.yaml | 6 + training/conf/config.yaml | 2 + training/conf/experiment/vqvae.yaml | 20 ++ training/conf/experiment/vqvae_experiment.yaml | 13 - training/conf/model/lit_vqvae.yaml | 4 +- training/conf/network/conv_transformer.yaml | 2 +- .../conf/network/decoder/transformer_decoder.yaml | 4 +- training/conf/network/vqvae.yaml | 21 +- 21 files changed, 893 insertions(+), 675 deletions(-) create mode 100644 text_recognizer/networks/vq_transformer.py create mode 100644 text_recognizer/networks/vqvae/quantizer.py create mode 100644 text_recognizer/networks/vqvae/residual.py delete mode 100644 text_recognizer/networks/vqvae/vector_quantizer.py create mode 100644 training/conf/callbacks/wandb_vae.yaml create mode 100644 training/conf/experiment/vqvae.yaml delete mode 100644 training/conf/experiment/vqvae_experiment.yaml diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb index a193107..9f056bc 100644 --- a/notebooks/00-scratch-pad.ipynb +++ b/notebooks/00-scratch-pad.ipynb @@ -27,6 +27,209 @@ "from text_recognizer.networks.transformer.layers import Decoder" ] }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randint(0, 5, (4, 4))" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "36" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "576 // 16" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "40" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "640 // 16" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1440" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "36 * 40" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0, 1, 2, 1],\n", + " [1, 2, 3, 3],\n", + " [2, 2, 3, 3],\n", + " [4, 0, 2, 4]])" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randint(0, 5, (1, 4, 4, 4))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[[2, 3, 3, 3],\n", + " [3, 4, 4, 2],\n", + " [2, 3, 0, 0],\n", + " [4, 3, 4, 0]],\n", + "\n", + " [[3, 0, 3, 0],\n", + " [1, 4, 1, 3],\n", + " [2, 3, 3, 3],\n", + " [2, 3, 3, 1]],\n", + "\n", + " [[1, 1, 0, 3],\n", + " [1, 3, 0, 4],\n", + " [3, 1, 4, 2],\n", + " [3, 1, 4, 3]],\n", + "\n", + " [[3, 2, 3, 4],\n", + " [3, 2, 3, 3],\n", + " [0, 2, 2, 3],\n", + " [4, 0, 3, 4]]]])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 4, 16])" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.flatten(start_dim=2).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[[2, 3, 3, 3, 3, 4, 4, 2, 2, 3, 0, 0, 4, 3, 4, 0],\n", + " [3, 0, 3, 0, 1, 4, 1, 3, 2, 3, 3, 3, 2, 3, 3, 1],\n", + " [1, 1, 0, 3, 1, 3, 0, 4, 3, 1, 4, 2, 3, 1, 4, 3],\n", + " [3, 2, 3, 4, 3, 2, 3, 3, 0, 2, 2, 3, 4, 0, 3, 4]]])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.flatten(start_dim=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "__init__() got an unexpected keyword argument 'dim'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_6532/3641656095.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mflatten\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFlatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'dim'" + ] + } + ], + "source": [ + "flatten = nn.Flatten(stdim=2)" + ] + }, { "cell_type": "code", "execution_count": 2, @@ -561,9 +764,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 65, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "TypeError", + "evalue": "__init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_9275/689714588.py\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_heads\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mff_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcross_attend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;34m\"causal\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Cannot set causality on decoder\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcausal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mTypeError\u001b[0m: __init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'" + ] + } + ], "source": [ "decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)" ] diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index e3e92e2..850d205 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -2,19 +2,10 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "id": "1e40a88b", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -34,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 2, "id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0", "metadata": {}, "outputs": [], @@ -47,67 +38,8 @@ { "cell_type": "code", "execution_count": 3, - "id": "6b722ca0-9c65-4f90-be4e-b7334ea81237", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "_target_: text_recognizer.models.transformer.TransformerLitModel\n", - "interval: step\n", - "monitor: val/loss\n", - "start_token: \n", - "end_token: \n", - "pad_token:

\n", - "\n", - "{'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'start_token': '', 'end_token': '', 'pad_token': '

'}\n" - ] - } - ], - "source": [ - "# context initialization\n", - "with initialize(config_path=\"../training/conf/model/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"lit_transformer\")\n", - " print(OmegaConf.to_yaml(cfg))\n", - " print(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5e6b49ce-7685-4491-bd0a-51487f06a237", - "metadata": {}, - "outputs": [], - "source": [ - "# context initialization\n", - "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"word_piece\")\n", - " print(OmegaConf.to_yaml(cfg))\n", - " print(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9c797159-845e-42c6-bd65-1c976ad627cd", - "metadata": {}, - "outputs": [], - "source": [ - "# context initialization\n", - "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"conv_transformer\")\n", - " print(OmegaConf.to_yaml(cfg))\n", - " print(cfg)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, "id": "764c8736-7d68-4261-a57d-face10ebbf42", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [ { "name": "stdout", @@ -122,8 +54,7 @@ " mode: min\n", " verbose: false\n", " dirpath: checkpoints/\n", - " filename:\n", - " epoch:02d: null\n", + " filename: '{epoch:02d}'\n", " learning_rate_monitor:\n", " _target_: pytorch_lightning.callbacks.LearningRateMonitor\n", " logging_interval: step\n", @@ -139,20 +70,20 @@ " _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", " ckpt_dir: checkpoints/\n", " upload_best_only: true\n", - " log_text_predictions:\n", - " _target_: callbacks.wandb_callbacks.LogTextPredictions\n", + " log_image_reconstruction:\n", + " _target_: callbacks.wandb_callbacks.LogReconstuctedImages\n", " num_samples: 8\n", "criterion:\n", - " _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss\n", - " smoothing: 0.1\n", - " ignore_index: 1002\n", + " _target_: torch.nn.MSELoss\n", + " reduction: mean\n", "datamodule:\n", " _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n", - " batch_size: 8\n", + " batch_size: 32\n", " num_workers: 12\n", " train_fraction: 0.8\n", " augment: true\n", " pin_memory: false\n", + " word_pieces: true\n", "logger:\n", " wandb:\n", " _target_: pytorch_lightning.loggers.wandb.WandbLogger\n", @@ -170,8 +101,8 @@ " _target_: torch.optim.lr_scheduler.OneCycleLR\n", " max_lr: 0.001\n", " total_steps: null\n", - " epochs: 512\n", - " steps_per_epoch: 4992\n", + " epochs: 64\n", + " steps_per_epoch: 624\n", " pct_start: 0.3\n", " anneal_strategy: cos\n", " cycle_momentum: true\n", @@ -199,52 +130,21 @@ "\n", " '\n", "model:\n", - " _target_: text_recognizer.models.transformer.TransformerLitModel\n", + " _target_: text_recognizer.models.vqvae.VQVAELitModel\n", " interval: step\n", " monitor: val/loss\n", - " max_output_len: 451\n", - " start_token: \n", - " end_token: \n", - " pad_token:

\n", "network:\n", - " encoder:\n", - " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n", - " arch: b0\n", - " out_channels: 1280\n", - " stochastic_dropout_rate: 0.2\n", - " bn_momentum: 0.99\n", - " bn_eps: 0.001\n", - " decoder:\n", - " _target_: text_recognizer.networks.transformer.Decoder\n", - " dim: 96\n", - " depth: 2\n", - " num_heads: 8\n", - " attn_fn: text_recognizer.networks.transformer.attention.Attention\n", - " attn_kwargs:\n", - " dim_head: 16\n", - " dropout_rate: 0.2\n", - " norm_fn: torch.nn.LayerNorm\n", - " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n", - " ff_kwargs:\n", - " dim_out: null\n", - " expansion_factor: 4\n", - " glu: true\n", - " dropout_rate: 0.2\n", - " cross_attend: true\n", - " pre_norm: true\n", - " rotary_emb: null\n", - " _target_: text_recognizer.networks.conv_transformer.ConvTransformer\n", - " input_dims:\n", - " - 1\n", - " - 576\n", - " - 640\n", - " hidden_dim: 96\n", - " dropout_rate: 0.2\n", - " num_classes: 1006\n", - " pad_index: 1002\n", + " _target_: text_recognizer.networks.vqvae.VQVAE\n", + " in_channels: 1\n", + " res_channels: 32\n", + " num_residual_layers: 2\n", + " embedding_dim: 64\n", + " num_embeddings: 512\n", + " decay: 0.99\n", + " activation: mish\n", "optimizer:\n", " _target_: madgrad.MADGRAD\n", - " lr: 0.001\n", + " lr: 0.01\n", " momentum: 0.9\n", " weight_decay: 0\n", " eps: 1.0e-06\n", @@ -257,7 +157,7 @@ " fast_dev_run: false\n", " gpus: 1\n", " precision: 16\n", - " max_epochs: 512\n", + " max_epochs: 64\n", " terminate_on_nan: true\n", " weights_summary: top\n", " limit_train_batches: 1.0\n", @@ -269,91 +169,181 @@ "train: true\n", "test: true\n", "logging: INFO\n", + "work_dir: ${hydra:runtime.cwd}\n", "debug: false\n", + "print_config: true\n", + "ignore_warnings: true\n", "\n", - "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}, 'criterion': {'_target_': 'text_recognizer.criterions.label_smoothing.LabelSmoothingLoss', 'smoothing': 0.1, 'ignore_index': 1002}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 8, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 512, 'steps_per_epoch': 4992, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['', '', '

'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'max_output_len': 451, 'start_token': '', 'end_token': '', 'pad_token': '

'}, 'network': {'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 96, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'dim_head': 16, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'cross_attend': True, 'pre_norm': True, 'rotary_emb': None}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 576, 640], 'hidden_dim': 96, 'dropout_rate': 0.2, 'num_classes': 1006, 'pad_index': 1002}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 512, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'debug': False}\n" + "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['', '', '

'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss'}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n" ] } ], "source": [ "# context initialization\n", "with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n", - " cfg = compose(config_name=\"config\")\n", + " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqvae\"])\n", " print(OmegaConf.to_yaml(cfg))\n", " print(cfg)" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "9382f0ab-8760-4d59-b0b5-b8b65dd1ea31", + "execution_count": 4, + "id": "c1a9aa6b-6405-4ffe-b065-02340762476a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-08-04 04:49:04.188 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" + ] + } + ], + "source": [ + "mapping = instantiate(cfg.mapping)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86", + "metadata": {}, + "outputs": [], + "source": [ + "network = instantiate(cfg.network)" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c", + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.rand(1, 1, 576, 640)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}" + "torch.Size([1, 64, 144, 160])" ] }, - "execution_count": 10, + "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "cfg.get(\"callbacks\")" + "network.encode(x)[0].shape" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "216d5680-66bf-4190-9401-1a59dbbc43af", + "execution_count": 38, + "id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "pytorch_lightning.callbacks.ModelCheckpoint\n", - "pytorch_lightning.callbacks.LearningRateMonitor\n", - "callbacks.wandb_callbacks.WatchModel\n", - "callbacks.wandb_callbacks.UploadCodeAsArtifact\n", - "callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n", - "callbacks.wandb_callbacks.LogTextPredictions\n" + "torch.Size([512])\n", + "torch.Size([512])\n", + "torch.Size([512])\n", + "torch.Size([512])\n" ] + }, + { + "data": { + "text/plain": [ + "torch.Size([1, 1, 576, 640])" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "for l in cfg.callbacks.values():\n", - " print(l.get(\"_target_\"))" + "network(x)[0].shape" ] }, { "cell_type": "code", - "execution_count": 4, - "id": "c1a9aa6b-6405-4ffe-b065-02340762476a", + "execution_count": null, + "id": "23c9d90c-042b-423e-ab85-18449e29ded4", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-08-03 15:27:02.069 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n" - ] - } - ], + "outputs": [], "source": [ - "mapping = instantiate(cfg.mapping)" + "576 / 4" ] }, { "cell_type": "code", - "execution_count": 5, - "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86", + "execution_count": null, + "id": "047ebc09-1c74-44a7-a314-1099f09722fe", "metadata": {}, "outputs": [], "source": [ - "network = instantiate(cfg.network)" + "t = torch.randint(0, 1006, (1, 451)).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "87372dde-2b1a-432b-ab79-0b116124c724", + "metadata": {}, + "outputs": [], + "source": [ + "z = torch.rand((1, 36 * 40, 128)).cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cf7ca9bf-cafa-4128-9db7-046c16933a52", + "metadata": {}, + "outputs": [], + "source": [ + "network = network.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dfceaa5f-9ad8-4d33-addb-c56e8da48356", + "metadata": {}, + "outputs": [], + "source": [ + "network.decode(z, t).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9105fbbb-4363-4d3e-a01e-bc519c3b9c3a", + "metadata": {}, + "outputs": [], + "source": [ + "decoder = decoder.cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c5797ec4-7a6a-46fd-8adc-265df44d0341", + "metadata": {}, + "outputs": [], + "source": [ + "decoder(z, t).shape" ] }, { @@ -368,11 +358,9 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "a6fae1fa-492d-4648-80fd-1c0dac659b02", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "datamodule = instantiate(cfg.datamodule, mapping=mapping)" @@ -380,19 +368,10 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "514053ef-fcac-4f3c-a7c8-72c6927d6798", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2021-08-03 15:28:22.541 | INFO | text_recognizer.data.iam_paragraphs:setup:95 - Loading IAM paragraph regions and lines for None...\n", - "2021-08-03 15:28:45.280 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:68 - IAM Synthetic dataset steup for stage None...\n" - ] - } - ], + "outputs": [], "source": [ "datamodule.prepare_data()\n", "datamodule.setup()" @@ -400,21 +379,10 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "4bad950b-a197-4c60-ad89-903124659a98", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "4992" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "len(datamodule.train_dataloader())" ] @@ -431,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "id": "f6e01c15-9a1b-4036-87ae-78716c592264", "metadata": {}, "outputs": [], @@ -441,7 +409,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "id": "4dc475fc-31f4-487e-88c8-b0f445131f5b", "metadata": {}, "outputs": [], @@ -451,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "id": "c5c8ed64-d98c-47b5-baf2-1ba57a6c882f", "metadata": {}, "outputs": [], @@ -461,11 +429,9 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "id": "b5ff5b24-f804-402b-a8ab-f366443025ca", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ " model = hydra.utils.instantiate(\n", @@ -481,21 +447,10 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "id": "99f8a39f-8b10-4f7d-8bff-52794fd48717", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - ">" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "mapping.get_index" ] @@ -514,9 +469,7 @@ "cell_type": "code", "execution_count": null, "id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f", - "metadata": { - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [ "net" diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 22da018..5890fd9 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -14,31 +14,33 @@ from text_recognizer.models.base import BaseLitModel class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" + latent_loss_weight: float = attr.ib(default=0.25) + def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" - return self.network.predict(data) + return self.network(data) def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("train/loss", loss) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("val/loss", loss, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("test/loss", loss) diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py new file mode 100644 index 0000000..a972565 --- /dev/null +++ b/text_recognizer/networks/vq_transformer.py @@ -0,0 +1,77 @@ +"""Vector quantized encoder, transformer decoder.""" +import math +from typing import Tuple + +from torch import nn, Tensor + +from text_recognizer.networks.encoders.efficientnet import EfficientNet +from text_recognizer.networks.conv_transformer import ConvTransformer +from text_recognizer.networks.transformer.layers import Decoder +from text_recognizer.networks.transformer.positional_encodings import ( + PositionalEncoding, + PositionalEncoding2D, +) + + +class VqTransformer(ConvTransformer): + """Convolutional encoder and transformer decoder network.""" + + def __init__( + self, + input_dims: Tuple[int, int, int], + hidden_dim: int, + dropout_rate: float, + num_classes: int, + pad_index: Tensor, + encoder: EfficientNet, + decoder: Decoder, + ) -> None: + # TODO: Load pretrained vqvae encoder. + super().__init__( + input_dims=input_dims, + hidden_dim=hidden_dim, + dropout_rate=dropout_rate, + num_classes=num_classes, + pad_index=pad_index, + encoder=encoder, + decoder=decoder, + ) + # Latent projector for down sampling number of filters and 2d + # positional encoding. + self.latent_encoder = nn.Sequential( + nn.Conv2d( + in_channels=self.encoder.out_channels, + out_channels=self.hidden_dim, + kernel_size=1, + ), + PositionalEncoding2D( + hidden_dim=self.hidden_dim, + max_h=self.input_dims[1], + max_w=self.input_dims[2], + ), + nn.Flatten(start_dim=2), + ) + + def encode(self, x: Tensor) -> Tensor: + """Encodes an image into a latent feature vector. + + Args: + x (Tensor): Image tensor. + + Shape: + - x: :math: `(B, C, H, W)` + - z: :math: `(B, Sx, E)` + + where Sx is the length of the flattened feature maps projected from + the encoder. E latent dimension for each pixel in the projected + feature maps. + + Returns: + Tensor: A Latent embedding of the image. + """ + z = self.encoder(x) + z = self.latent_encoder(z) + + # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] + z = z.permute(0, 2, 1) + return z diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py index 763953c..7d56bdb 100644 --- a/text_recognizer/networks/vqvae/__init__.py +++ b/text_recognizer/networks/vqvae/__init__.py @@ -1,5 +1,2 @@ """VQ-VAE module.""" -from .decoder import Decoder -from .encoder import Encoder -from .vector_quantizer import VectorQuantizer from .vqvae import VQVAE diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 32de912..3f59f0d 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -1,133 +1,65 @@ """CNN decoder for the VQ-VAE.""" - -from typing import List, Optional, Tuple, Type - -import torch +import attr from torch import nn from torch import Tensor from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.encoder import _ResidualBlock +from text_recognizer.networks.vqvae.residual import Residual +@attr.s(eq=False) class Decoder(nn.Module): """A CNN encoder network.""" - def __init__( - self, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - upsampling: Optional[List[List[int]]] = None, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.upsampling = upsampling - - self.res_block = nn.ModuleList([]) - self.upsampling_block = nn.ModuleList([]) - - self.embedding_dim = embedding_dim - activation = activation_function(activation) - - # Configure encoder. - self.decoder = self._build_decoder( - channels, kernel_sizes, strides, num_residual_layers, activation, dropout, - ) - - def _build_decompression_block( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for i, (out_channels, kernel_size, stride) in enumerate(configuration): - modules.append( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, - out_channels, - kernel_size, - stride=stride, - padding=1, - ), - activation, - ) - ) - - if self.upsampling and i < len(self.upsampling): - modules.append(nn.Upsample(size=self.upsampling[i]),) + in_channels: int = attr.ib() + embedding_dim: int = attr.ib() + out_channels: int = attr.ib() + res_channels: int = attr.ib() + num_residual_layers: int = attr.ib() + activation: str = attr.ib() + decoder: nn.Sequential = attr.ib(init=False) - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - modules.extend( - nn.Sequential( - nn.ConvTranspose2d( - in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 - ), - nn.Tanh(), - ) - ) - - return modules - - def _build_decoder( - self, - channels: int, - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - - self.res_block.append( - nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) - ) + def __attrs_pre_init__(self) -> None: + super().__init__() - # Bottleneck module. - self.res_block.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[0], channels[0], dropout) - for i in range(num_residual_layers) - ] + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.decoder = self._build_decompression_block() + + def _build_decompression_block(self,) -> nn.Sequential: + activation_fn = activation_function(self.activation) + blocks = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.embedding_dim, + kernel_size=3, + padding=1, ) - ) - - # Decompression module - self.upsampling_block.extend( - self._build_decompression_block( - channels[0], channels[1:], kernel_sizes, strides, activation, dropout + ] + for _ in range(self.num_residual_layers): + blocks.append( + Residual(in_channels=self.embedding_dim, out_channels=self.res_channels) ) - ) - - self.res_block = nn.Sequential(*self.res_block) - self.upsampling_block = nn.Sequential(*self.upsampling_block) - - return nn.Sequential(self.res_block, self.upsampling_block) + blocks.append(activation_fn) + blocks += [ + nn.ConvTranspose2d( + in_channels=self.embedding_dim, + out_channels=self.embedding_dim // 2, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.ConvTranspose2d( + in_channels=self.embedding_dim // 2, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + ), + ] + return nn.Sequential(*blocks) def forward(self, z_q: Tensor) -> Tensor: """Reconstruct input from given codes.""" - x_reconstruction = self.decoder(z_q) - return x_reconstruction + return self.decoder(z_q) diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index 65801df..e480545 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -1,147 +1,75 @@ """CNN encoder for the VQ-VAE.""" from typing import Sequence, Optional, Tuple, Type -import torch +import attr from torch import nn from torch import Tensor from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer - - -class _ResidualBlock(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], - ) -> None: - super().__init__() - self.block = [ - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), - ] - - if dropout is not None: - self.block.append(dropout) - - self.block = nn.Sequential(*self.block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the residual forward pass.""" - return x + self.block(x) +from text_recognizer.networks.vqvae.residual import Residual +@attr.s(eq=False) class Encoder(nn.Module): """A CNN encoder network.""" - def __init__( - self, - in_channels: int, - channels: Sequence[int], - kernel_sizes: Sequence[int], - strides: Sequence[int], - num_residual_layers: int, - embedding_dim: int, - num_embeddings: int, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.embedding_dim = embedding_dim - self.num_embeddings = num_embeddings - self.beta = beta - activation = activation_function(activation) - - # Configure encoder. - self.encoder = self._build_encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, - ) + in_channels: int = attr.ib() + out_channels: int = attr.ib() + res_channels: int = attr.ib() + num_residual_layers: int = attr.ib() + embedding_dim: int = attr.ib() + activation: str = attr.ib() + encoder: nn.Sequential = attr.ib(init=False) - # Configure Vector Quantizer. - self.vector_quantizer = VectorQuantizer( - self.num_embeddings, self.embedding_dim, self.beta - ) - - @staticmethod - def _build_compression_block( - in_channels: int, - channels: int, - kernel_sizes: Sequence[int], - strides: Sequence[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for out_channels, kernel_size, stride in configuration: - modules.append( - nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=1 - ), - activation, - ) - ) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - return modules + def __attrs_pre_init__(self) -> None: + super().__init__() - def _build_encoder( - self, - in_channels: int, - channels: int, - kernel_sizes: Sequence[int], - strides: Sequence[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - encoder = nn.ModuleList([]) + def __attrs_post_init__(self) -> None: + """Post init configuration.""" + self.encoder = self._build_compression_block() + + def _build_compression_block(self) -> nn.Sequential: + activation_fn = activation_function(self.activation) + block = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels // 2, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.Conv2d( + in_channels=self.out_channels // 2, + out_channels=self.out_channels, + kernel_size=4, + stride=2, + padding=1, + ), + activation_fn, + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + ), + ] - # compression module - encoder.extend( - self._build_compression_block( - in_channels, channels, kernel_sizes, strides, activation, dropout + for _ in range(self.num_residual_layers): + block.append( + Residual(in_channels=self.out_channels, out_channels=self.res_channels) ) - ) - # Bottleneck module. - encoder.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[-1], channels[-1], dropout) - for i in range(num_residual_layers) - ] + block.append( + nn.Conv2d( + in_channels=self.out_channels, + out_channels=self.embedding_dim, + kernel_size=1, ) ) - encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) - ) - - return nn.Sequential(*encoder) + return nn.Sequential(*block) def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes input into a discrete representation.""" - z_e = self.encoder(x) - z_q, vq_loss = self.vector_quantizer(z_e) - return z_q, vq_loss + return self.encoder(x) diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py new file mode 100644 index 0000000..5e0b602 --- /dev/null +++ b/text_recognizer/networks/vqvae/quantizer.py @@ -0,0 +1,142 @@ +"""Implementation of a Vector Quantized Variational AutoEncoder. + +Reference: +https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py +""" +from einops import rearrange +import torch +from torch import nn +from torch import Tensor +from torch.nn import functional as F + + +class EmbeddingEMA(nn.Module): + def __init__(self, num_embeddings: int, embedding_dim: int) -> None: + super().__init__() + weight = torch.zeros(num_embeddings, embedding_dim) + nn.init.kaiming_uniform_(weight, nonlinearity="linear") + self.register_buffer("weight", weight) + self.register_buffer("_cluster_size", torch.zeros(num_embeddings)) + self.register_buffer("_weight_avg", weight) + + +class VectorQuantizer(nn.Module): + """The codebook that contains quantized vectors.""" + + def __init__( + self, num_embeddings: int, embedding_dim: int, decay: float = 0.99 + ) -> None: + super().__init__() + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.decay = decay + self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim) + + def discretization_bottleneck(self, latent: Tensor) -> Tensor: + """Computes the code nearest to the latent representation. + + First we compute the posterior categorical distribution, and then map + the latent representation to the nearest element of the embedding. + + Args: + latent (Tensor): The latent representation. + + Shape: + - latent :math:`(B x H x W, D)` + + Returns: + Tensor: The quantized embedding vector. + + """ + # Store latent shape. + b, h, w, d = latent.shape + + # Flatten the latent representation to 2D. + latent = rearrange(latent, "b h w d -> (b h w) d") + + # Compute the L2 distance between the latents and the embeddings. + l2_distance = ( + torch.sum(latent ** 2, dim=1, keepdim=True) + + torch.sum(self.embedding.weight ** 2, dim=1) + - 2 * latent @ self.embedding.weight.t() + ) # [BHW x K] + + # Find the embedding k nearest to each latent. + encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1] + + # Convert to one-hot encodings, aka discrete bottleneck. + one_hot_encoding = torch.zeros( + encoding_indices.shape[0], self.num_embeddings, device=latent.device + ) + one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K] + + # Embedding quantization. + quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D] + quantized_latent = rearrange( + quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w + ) + if self.training: + self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent) + + return quantized_latent + + def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: + batch_cluster_size = one_hot_encoding.sum(axis=0) + batch_embedding_avg = (latent.t() @ one_hot_encoding).t() + print(batch_cluster_size.shape) + print(self.embedding._cluster_size.shape) + self.embedding._cluster_size.data.mul_(self.decay).add_( + batch_cluster_size, alpha=1 - self.decay + ) + self.embedding._weight_avg.data.mul_(self.decay).add_( + batch_embedding_avg, alpha=1 - self.decay + ) + new_embedding = self.embedding._weight_avg / ( + self.embedding._cluster_size + 1.0e-5 + ).unsqueeze(1) + self.embedding.weight.data.copy_(new_embedding) + + def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: + """Vector Quantization loss. + + The vector quantization algorithm allows us to create a codebook. The VQ + algorithm works by moving the embedding vectors towards the encoder outputs. + + The embedding loss moves the embedding vector towards the encoder outputs. The + .detach() works as the stop gradient (sg) described in the paper. + + Because the volume of the embedding space is dimensionless, it can arbitarily + grow if the embeddings are not trained as fast as the encoder parameters. To + mitigate this, a commitment loss is added in the second term which makes sure + that the encoder commits to an embedding and that its output does not grow. + + Args: + latent (Tensor): The encoder output. + quantized_latent (Tensor): The quantized latent. + + Returns: + Tensor: The combinded VQ loss. + + """ + commitment_loss = F.mse_loss(quantized_latent.detach(), latent) + # embedding_loss = F.mse_loss(quantized_latent, latent.detach()) + # return embedding_loss + self.beta * commitment_loss + return commitment_loss + + def forward(self, latent: Tensor) -> Tensor: + """Forward pass that returns the quantized vector and the vq loss.""" + # Rearrange latent representation s.t. the hidden dim is at the end. + latent = rearrange(latent, "b d h w -> b h w d") + + # Maps latent to the nearest code in the codebook. + quantized_latent = self.discretization_bottleneck(latent) + + loss = self.vq_loss(latent, quantized_latent) + + # Add residue to the quantized latent. + quantized_latent = latent + (quantized_latent - latent).detach() + + # Rearrange the quantized shape back to the original shape. + quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") + + return quantized_latent, loss diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py new file mode 100644 index 0000000..98109b8 --- /dev/null +++ b/text_recognizer/networks/vqvae/residual.py @@ -0,0 +1,18 @@ +"""Residual block.""" +from torch import nn +from torch import Tensor + + +class Residual(nn.Module): + def __init__(self, in_channels: int, out_channels: int,) -> None: + super().__init__() + self.block = nn.Sequential( + nn.Mish(inplace=True), + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.Mish(inplace=True), + nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False), + ) + + def forward(self, x: Tensor) -> Tensor: + """Apply the residual forward pass.""" + return x + self.block(x) diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py deleted file mode 100644 index f92c7ee..0000000 --- a/text_recognizer/networks/vqvae/vector_quantizer.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Implementation of a Vector Quantized Variational AutoEncoder. - -Reference: -https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py - -""" - -from einops import rearrange -import torch -from torch import nn -from torch import Tensor -from torch.nn import functional as F - - -class VectorQuantizer(nn.Module): - """The codebook that contains quantized vectors.""" - - def __init__( - self, num_embeddings: int, embedding_dim: int, beta: float = 0.25 - ) -> None: - super().__init__() - self.K = num_embeddings - self.D = embedding_dim - self.beta = beta - - self.embedding = nn.Embedding(self.K, self.D) - - # Initialize the codebook. - nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) - - def discretization_bottleneck(self, latent: Tensor) -> Tensor: - """Computes the code nearest to the latent representation. - - First we compute the posterior categorical distribution, and then map - the latent representation to the nearest element of the embedding. - - Args: - latent (Tensor): The latent representation. - - Shape: - - latent :math:`(B x H x W, D)` - - Returns: - Tensor: The quantized embedding vector. - - """ - # Store latent shape. - b, h, w, d = latent.shape - - # Flatten the latent representation to 2D. - latent = rearrange(latent, "b h w d -> (b h w) d") - - # Compute the L2 distance between the latents and the embeddings. - l2_distance = ( - torch.sum(latent ** 2, dim=1, keepdim=True) - + torch.sum(self.embedding.weight ** 2, dim=1) - - 2 * latent @ self.embedding.weight.t() - ) # [BHW x K] - - # Find the embedding k nearest to each latent. - encoding_indices = torch.argmin(l2_distance, dim=1).unsqueeze(1) # [BHW, 1] - - # Convert to one-hot encodings, aka discrete bottleneck. - one_hot_encoding = torch.zeros( - encoding_indices.shape[0], self.K, device=latent.device - ) - one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K] - - # Embedding quantization. - quantized_latent = one_hot_encoding @ self.embedding.weight # [BHW, D] - quantized_latent = rearrange( - quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w - ) - - return quantized_latent - - def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: - """Vector Quantization loss. - - The vector quantization algorithm allows us to create a codebook. The VQ - algorithm works by moving the embedding vectors towards the encoder outputs. - - The embedding loss moves the embedding vector towards the encoder outputs. The - .detach() works as the stop gradient (sg) described in the paper. - - Because the volume of the embedding space is dimensionless, it can arbitarily - grow if the embeddings are not trained as fast as the encoder parameters. To - mitigate this, a commitment loss is added in the second term which makes sure - that the encoder commits to an embedding and that its output does not grow. - - Args: - latent (Tensor): The encoder output. - quantized_latent (Tensor): The quantized latent. - - Returns: - Tensor: The combinded VQ loss. - - """ - embedding_loss = F.mse_loss(quantized_latent, latent.detach()) - commitment_loss = F.mse_loss(quantized_latent.detach(), latent) - return embedding_loss + self.beta * commitment_loss - - def forward(self, latent: Tensor) -> Tensor: - """Forward pass that returns the quantized vector and the vq loss.""" - # Rearrange latent representation s.t. the hidden dim is at the end. - latent = rearrange(latent, "b d h w -> b h w d") - - # Maps latent to the nearest code in the codebook. - quantized_latent = self.discretization_bottleneck(latent) - - loss = self.vq_loss(latent, quantized_latent) - - # Add residue to the quantized latent. - quantized_latent = latent + (quantized_latent - latent).detach() - - # Rearrange the quantized shape back to the original shape. - quantized_latent = rearrange(quantized_latent, "b h w d -> b d h w") - - return quantized_latent, loss diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 5aa929b..1585d40 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -1,10 +1,14 @@ """The VQ-VAE.""" -from typing import Any, Dict, List, Optional, Tuple +from typing import Tuple +import torch from torch import nn from torch import Tensor +import torch.nn.functional as F -from text_recognizer.networks.vqvae import Decoder, Encoder +from text_recognizer.networks.vqvae.decoder import Decoder +from text_recognizer.networks.vqvae.encoder import Encoder +from text_recognizer.networks.vqvae.quantizer import VectorQuantizer class VQVAE(nn.Module): @@ -13,62 +17,92 @@ class VQVAE(nn.Module): def __init__( self, in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], + res_channels: int, num_residual_layers: int, embedding_dim: int, num_embeddings: int, - upsampling: Optional[List[List[int]]] = None, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - *args: Any, - **kwargs: Dict, + decay: float = 0.99, + activation: str = "mish", ) -> None: super().__init__() + # Encoders + self.btm_encoder = Encoder( + in_channels=1, + out_channels=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + embedding_dim=embedding_dim, + activation=activation, + ) + + self.top_encoder = Encoder( + in_channels=embedding_dim, + out_channels=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + embedding_dim=embedding_dim, + activation=activation, + ) + + # Quantizers + self.btm_quantizer = VectorQuantizer( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, + ) - # configure encoder. - self.encoder = Encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - num_embeddings, - beta, - activation, - dropout_rate, + self.top_quantizer = VectorQuantizer( + num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, ) - # Configure decoder. - channels.reverse() - kernel_sizes.reverse() - strides.reverse() - self.decoder = Decoder( - channels, - kernel_sizes, - strides, - num_residual_layers, - embedding_dim, - upsampling, - activation, - dropout_rate, + # Decoders + self.top_decoder = Decoder( + in_channels=embedding_dim, + out_channels=embedding_dim, + embedding_dim=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + activation=activation, + ) + + self.btm_decoder = Decoder( + in_channels=2 * embedding_dim, + out_channels=in_channels, + embedding_dim=embedding_dim, + res_channels=res_channels, + num_residual_layers=num_residual_layers, + activation=activation, ) def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes input to a latent code.""" - return self.encoder(x) + z_btm = self.btm_encoder(x) + z_top = self.top_encoder(z_btm) + return z_btm, z_top + + def quantize( + self, z_btm: Tensor, z_top: Tensor + ) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + q_btm, vq_btm_loss = self.top_quantizer(z_btm) + q_top, vq_top_loss = self.top_quantizer(z_top) + return q_btm, vq_btm_loss, q_top, vq_top_loss - def decode(self, z_q: Tensor) -> Tensor: + def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]: """Reconstructs input from latent codes.""" - return self.decoder(z_q) + d_top = self.top_decoder(q_top) + x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1)) + return d_top, x_hat + + def loss_fn( + self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor + ) -> Tensor: + """Calculates the latent loss.""" + return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm) def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Compresses and decompresses input.""" - if len(x.shape) < 4: - x = x[(None,) * (4 - len(x.shape))] - z_q, vq_loss = self.encode(x) - x_reconstruction = self.decode(z_q) - return x_reconstruction, vq_loss + z_btm, z_top = self.encode(x) + q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top) + d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top) + vq_loss = self.loss_fn( + vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm + ) + return x_hat, vq_loss diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py index 906531f..c750e4b 100644 --- a/training/callbacks/wandb_callbacks.py +++ b/training/callbacks/wandb_callbacks.py @@ -5,6 +5,7 @@ import wandb from pytorch_lightning import Callback, LightningModule, Trainer from pytorch_lightning.loggers import LoggerCollection, WandbLogger from pytorch_lightning.utilities import rank_zero_only +from torch.utils.data import DataLoader def get_wandb_logger(trainer: Trainer) -> WandbLogger: @@ -86,7 +87,11 @@ class LogTextPredictions(Callback): self.ready = False def _log_predictions( - self, stage: str, trainer: Trainer, pl_module: LightningModule + self, + stage: str, + trainer: Trainer, + pl_module: LightningModule, + dataloader: DataLoader, ) -> None: """Logs the predicted text contained in the images.""" if not self.ready: @@ -96,22 +101,20 @@ class LogTextPredictions(Callback): experiment = logger.experiment # Get a validation batch from the validation dataloader. - samples = next(iter(trainer.datamodule.val_dataloader())) + samples = next(iter(dataloader)) imgs, labels = samples imgs = imgs.to(device=pl_module.device) logits = pl_module(imgs) mapping = pl_module.mapping - columns = ["id", "image", "prediction", "truth"] + columns = ["image", "prediction", "truth"] data = [ - [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)] - for id, (img, pred, label) in enumerate( - zip( - imgs[: self.num_samples], - logits[: self.num_samples], - labels[: self.num_samples], - ) + [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)] + for img, pred, label in zip( + imgs[: self.num_samples], + logits[: self.num_samples], + labels[: self.num_samples], ) ] @@ -133,11 +136,17 @@ class LogTextPredictions(Callback): self, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs predictions on validation epoch end.""" - self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.val_dataloader() + self._log_predictions( + stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" - self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.test_dataloader() + self._log_predictions( + stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) class LogReconstuctedImages(Callback): @@ -148,7 +157,11 @@ class LogReconstuctedImages(Callback): self.ready = False def _log_reconstruction( - self, stage: str, trainer: Trainer, pl_module: LightningModule + self, + stage: str, + trainer: Trainer, + pl_module: LightningModule, + dataloader: DataLoader, ) -> None: """Logs the reconstructions.""" if not self.ready: @@ -158,20 +171,24 @@ class LogReconstuctedImages(Callback): experiment = logger.experiment # Get a validation batch from the validation dataloader. - samples = next(iter(trainer.datamodule.val_dataloader())) + samples = next(iter(dataloader)) imgs, _ = samples + colums = ["input", "reconstruction"] imgs = imgs.to(device=pl_module.device) - reconstructions = pl_module(imgs) + reconstructions = pl_module(imgs)[0] + data = [ + [wandb.Image(img), wandb.Image(rec)] + for img, rec in zip( + imgs[: self.num_samples], reconstructions[: self.num_samples] + ) + ] experiment.log( { - f"Reconstructions/{experiment.name}/{stage}": [ - [wandb.Image(img), wandb.Image(rec),] - for img, rec in zip( - imgs[: self.num_samples], reconstructions[: self.num_samples], - ) - ] + f"Reconstructions/{experiment.name}/{stage}": wandb.Table( + data=data, columns=colums + ) } ) @@ -189,8 +206,14 @@ class LogReconstuctedImages(Callback): self, trainer: Trainer, pl_module: LightningModule ) -> None: """Logs predictions on validation epoch end.""" - self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.val_dataloader() + self._log_reconstruction( + stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None: """Logs predictions on train epoch end.""" - self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module) + dataloader = trainer.datamodule.test_dataloader() + self._log_reconstruction( + stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader + ) diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml index e69de29..6cc4ada 100644 --- a/training/conf/callbacks/wandb_image_reconstructions.yaml +++ b/training/conf/callbacks/wandb_image_reconstructions.yaml @@ -0,0 +1,3 @@ +log_image_reconstruction: + _target_: callbacks.wandb_callbacks.LogReconstuctedImages + num_samples: 8 diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml new file mode 100644 index 0000000..609a8e8 --- /dev/null +++ b/training/conf/callbacks/wandb_vae.yaml @@ -0,0 +1,6 @@ +defaults: + - default + - wandb_watch + - wandb_code + - wandb_checkpoints + - wandb_image_reconstructions diff --git a/training/conf/config.yaml b/training/conf/config.yaml index 782bcbb..6b74502 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -1,3 +1,5 @@ +# @package _global_ + defaults: - callbacks: wandb_ocr - criterion: label_smoothing diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml new file mode 100644 index 0000000..13e5f34 --- /dev/null +++ b/training/conf/experiment/vqvae.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +defaults: + - override /network: vqvae + - override /criterion: mse + - override /model: lit_vqvae + - override /callbacks: wandb_vae + +trainer: + max_epochs: 64 + +datamodule: + batch_size: 32 + +lr_scheduler: + epochs: 64 + steps_per_epoch: 624 + +optimizer: + lr: 1.0e-2 diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml deleted file mode 100644 index 0858c3d..0000000 --- a/training/conf/experiment/vqvae_experiment.yaml +++ /dev/null @@ -1,13 +0,0 @@ -defaults: - - override /network: vqvae - - override /criterion: mse - - override /optimizer: madgrad - - override /lr_scheduler: one_cycle - - override /model: lit_vqvae - - override /dataset: iam_extended_paragraphs - - override /trainer: default - - override /callbacks: - - wandb - -load_checkpoint: null -logging: INFO diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml index b337fe6..8837573 100644 --- a/training/conf/model/lit_vqvae.yaml +++ b/training/conf/model/lit_vqvae.yaml @@ -1,2 +1,4 @@ _target_: text_recognizer.models.vqvae.VQVAELitModel -mapping: sentence_piece +interval: step +monitor: val/loss +latent_loss_weight: 0.25 diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml index f76e892..d3a3b0f 100644 --- a/training/conf/network/conv_transformer.yaml +++ b/training/conf/network/conv_transformer.yaml @@ -4,7 +4,7 @@ defaults: _target_: text_recognizer.networks.conv_transformer.ConvTransformer input_dims: [1, 576, 640] -hidden_dim: 96 +hidden_dim: 128 dropout_rate: 0.2 num_classes: 1006 pad_index: 1002 diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml index eb80f64..c326c04 100644 --- a/training/conf/network/decoder/transformer_decoder.yaml +++ b/training/conf/network/decoder/transformer_decoder.yaml @@ -2,12 +2,12 @@ defaults: - rotary_emb: null _target_: text_recognizer.networks.transformer.Decoder -dim: 96 +dim: 128 depth: 2 num_heads: 8 attn_fn: text_recognizer.networks.transformer.attention.Attention attn_kwargs: - dim_head: 16 + dim_head: 64 dropout_rate: 0.2 norm_fn: torch.nn.LayerNorm ff_fn: text_recognizer.networks.transformer.mlp.FeedForward diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml index 22eebf8..5a5c066 100644 --- a/training/conf/network/vqvae.yaml +++ b/training/conf/network/vqvae.yaml @@ -1,13 +1,8 @@ -type: VQVAE -args: - in_channels: 1 - channels: [64, 96] - kernel_sizes: [4, 4] - strides: [2, 2] - num_residual_layers: 2 - embedding_dim: 64 - num_embeddings: 256 - upsampling: null - beta: 0.25 - activation: leaky_relu - dropout_rate: 0.2 +_target_: text_recognizer.networks.vqvae.VQVAE +in_channels: 1 +res_channels: 32 +num_residual_layers: 2 +embedding_dim: 64 +num_embeddings: 512 +decay: 0.99 +activation: mish -- cgit v1.2.3-70-g09d2