author     Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-04 05:03:51 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>  2021-08-04 05:03:51 +0200
commit     d3afa310f77f47553586eeee58e3d3345a754e2c (patch)
tree       08b7de1daf2550852d0a1e4d4d75202f14bb03d4
parent     65d5f6c694e73792e40ed693a1381a792da8d277 (diff)
New VQVAE
-rw-r--r--  notebooks/00-scratch-pad.ipynb | 220
-rw-r--r--  notebooks/05c-test-model-end-to-end.ipynb | 367
-rw-r--r--  text_recognizer/models/vqvae.py | 16
-rw-r--r--  text_recognizer/networks/vq_transformer.py | 77
-rw-r--r--  text_recognizer/networks/vqvae/__init__.py | 3
-rw-r--r--  text_recognizer/networks/vqvae/decoder.py | 164
-rw-r--r--  text_recognizer/networks/vqvae/encoder.py | 176
-rw-r--r--  text_recognizer/networks/vqvae/quantizer.py (renamed from text_recognizer/networks/vqvae/vector_quantizer.py) | 51
-rw-r--r--  text_recognizer/networks/vqvae/residual.py | 18
-rw-r--r--  text_recognizer/networks/vqvae/vqvae.py | 122
-rw-r--r--  training/callbacks/wandb_callbacks.py | 69
-rw-r--r--  training/conf/callbacks/wandb_image_reconstructions.yaml | 3
-rw-r--r--  training/conf/callbacks/wandb_vae.yaml | 6
-rw-r--r--  training/conf/config.yaml | 2
-rw-r--r--  training/conf/experiment/vqvae.yaml | 20
-rw-r--r--  training/conf/experiment/vqvae_experiment.yaml | 13
-rw-r--r--  training/conf/model/lit_vqvae.yaml | 4
-rw-r--r--  training/conf/network/conv_transformer.yaml | 2
-rw-r--r--  training/conf/network/decoder/transformer_decoder.yaml | 4
-rw-r--r--  training/conf/network/vqvae.yaml | 21
20 files changed, 788 insertions, 570 deletions
diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb
index a193107..9f056bc 100644
--- a/notebooks/00-scratch-pad.ipynb
+++ b/notebooks/00-scratch-pad.ipynb
@@ -29,6 +29,209 @@
},
{
"cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randint(0, 5, (4, 4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "36"
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "576 // 16"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "40"
+ ]
+ },
+ "execution_count": 22,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "640 // 16"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "1440"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "36 * 40"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[0, 1, 2, 1],\n",
+ " [1, 2, 3, 3],\n",
+ " [2, 2, 3, 3],\n",
+ " [4, 0, 2, 4]])"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randint(0, 5, (1, 4, 4, 4))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[[2, 3, 3, 3],\n",
+ " [3, 4, 4, 2],\n",
+ " [2, 3, 0, 0],\n",
+ " [4, 3, 4, 0]],\n",
+ "\n",
+ " [[3, 0, 3, 0],\n",
+ " [1, 4, 1, 3],\n",
+ " [2, 3, 3, 3],\n",
+ " [2, 3, 3, 1]],\n",
+ "\n",
+ " [[1, 1, 0, 3],\n",
+ " [1, 3, 0, 4],\n",
+ " [3, 1, 4, 2],\n",
+ " [3, 1, 4, 3]],\n",
+ "\n",
+ " [[3, 2, 3, 4],\n",
+ " [3, 2, 3, 3],\n",
+ " [0, 2, 2, 3],\n",
+ " [4, 0, 3, 4]]]])"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 4, 16])"
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.flatten(start_dim=2).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([[[2, 3, 3, 3, 3, 4, 4, 2, 2, 3, 0, 0, 4, 3, 4, 0],\n",
+ " [3, 0, 3, 0, 1, 4, 1, 3, 2, 3, 3, 3, 2, 3, 3, 1],\n",
+ " [1, 1, 0, 3, 1, 3, 0, 4, 3, 1, 4, 2, 3, 1, 4, 3],\n",
+ " [3, 2, 3, 4, 3, 2, 3, 3, 0, 2, 2, 3, 4, 0, 3, 4]]])"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "t.flatten(start_dim=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "__init__() got an unexpected keyword argument 'dim'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/tmp/ipykernel_6532/3641656095.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mflatten\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mFlatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m: __init__() got an unexpected keyword argument 'dim'"
+ ]
+ }
+ ],
+ "source": [
+ "flatten = nn.Flatten(stdim=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
@@ -561,9 +764,22 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 65,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "ename": "TypeError",
+ "evalue": "__init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
+ "\u001b[0;32m/tmp/ipykernel_9275/689714588.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mdecoder\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mDecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdim\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m128\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdepth\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_heads\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m8\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mff_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattn_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcross_attend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;32m~/projects/text-recognizer/text_recognizer/networks/transformer/layers.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 104\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 105\u001b[0m \u001b[0;32massert\u001b[0m \u001b[0;34m\"causal\"\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"Cannot set causality on decoder\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 106\u001b[0;31m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcausal\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
+ "\u001b[0;31mTypeError\u001b[0m: __init__() missing 4 required positional arguments: 'attn_fn', 'norm_fn', 'ff_fn', and 'rotary_emb'"
+ ]
+ }
+ ],
"source": [
"decoder = Decoder(dim=128, depth=2, num_heads=8, ff_kwargs={}, attn_kwargs={}, cross_attend=True)"
]
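
The cells above work out the latent grid for 576 x 640 inputs (576 // 16 = 36, 640 // 16 = 40, i.e. 1440 positions) and probe how to flatten a feature map into a sequence; the TypeError cell shows that torch.nn.Flatten does not take a dim argument. A minimal sketch of the working variant with start_dim, not part of the commit, just for reference:

import torch
from torch import nn

# A stand-in feature map: batch 1, 4 channels, a 4 x 4 spatial grid, as in the cells above.
t = torch.randint(0, 5, (1, 4, 4, 4))

# nn.Flatten takes start_dim/end_dim, not dim; hence the TypeError above.
flatten = nn.Flatten(start_dim=2)
z = flatten(t)          # shape [1, 4, 16], same as t.flatten(start_dim=2)
z = z.permute(0, 2, 1)  # shape [1, 16, 4]: one row per spatial position
print(z.shape)
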
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index e3e92e2..850d205 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -2,19 +2,10 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 1,
"id": "1e40a88b",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "The autoreload extension is already loaded. To reload it, use:\n",
- " %reload_ext autoreload\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
@@ -34,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": 2,
"id": "d3a6146b-94b1-4618-a4e4-00f8e23ffdb0",
"metadata": {},
"outputs": [],
@@ -47,67 +38,8 @@
{
"cell_type": "code",
"execution_count": 3,
- "id": "6b722ca0-9c65-4f90-be4e-b7334ea81237",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "_target_: text_recognizer.models.transformer.TransformerLitModel\n",
- "interval: step\n",
- "monitor: val/loss\n",
- "start_token: <s>\n",
- "end_token: <e>\n",
- "pad_token: <p>\n",
- "\n",
- "{'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}\n"
- ]
- }
- ],
- "source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/model/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"lit_transformer\")\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "5e6b49ce-7685-4491-bd0a-51487f06a237",
- "metadata": {},
- "outputs": [],
- "source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/mapping/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"word_piece\")\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "9c797159-845e-42c6-bd65-1c976ad627cd",
- "metadata": {},
- "outputs": [],
- "source": [
- "# context initialization\n",
- "with initialize(config_path=\"../training/conf/network/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"conv_transformer\")\n",
- " print(OmegaConf.to_yaml(cfg))\n",
- " print(cfg)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
"id": "764c8736-7d68-4261-a57d-face10ebbf42",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [
{
"name": "stdout",
@@ -122,8 +54,7 @@
" mode: min\n",
" verbose: false\n",
" dirpath: checkpoints/\n",
- " filename:\n",
- " epoch:02d: null\n",
+ " filename: '{epoch:02d}'\n",
" learning_rate_monitor:\n",
" _target_: pytorch_lightning.callbacks.LearningRateMonitor\n",
" logging_interval: step\n",
@@ -139,20 +70,20 @@
" _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
" ckpt_dir: checkpoints/\n",
" upload_best_only: true\n",
- " log_text_predictions:\n",
- " _target_: callbacks.wandb_callbacks.LogTextPredictions\n",
+ " log_image_reconstruction:\n",
+ " _target_: callbacks.wandb_callbacks.LogReconstuctedImages\n",
" num_samples: 8\n",
"criterion:\n",
- " _target_: text_recognizer.criterions.label_smoothing.LabelSmoothingLoss\n",
- " smoothing: 0.1\n",
- " ignore_index: 1002\n",
+ " _target_: torch.nn.MSELoss\n",
+ " reduction: mean\n",
"datamodule:\n",
" _target_: text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs\n",
- " batch_size: 8\n",
+ " batch_size: 32\n",
" num_workers: 12\n",
" train_fraction: 0.8\n",
" augment: true\n",
" pin_memory: false\n",
+ " word_pieces: true\n",
"logger:\n",
" wandb:\n",
" _target_: pytorch_lightning.loggers.wandb.WandbLogger\n",
@@ -170,8 +101,8 @@
" _target_: torch.optim.lr_scheduler.OneCycleLR\n",
" max_lr: 0.001\n",
" total_steps: null\n",
- " epochs: 512\n",
- " steps_per_epoch: 4992\n",
+ " epochs: 64\n",
+ " steps_per_epoch: 624\n",
" pct_start: 0.3\n",
" anneal_strategy: cos\n",
" cycle_momentum: true\n",
@@ -199,52 +130,21 @@
"\n",
" '\n",
"model:\n",
- " _target_: text_recognizer.models.transformer.TransformerLitModel\n",
+ " _target_: text_recognizer.models.vqvae.VQVAELitModel\n",
" interval: step\n",
" monitor: val/loss\n",
- " max_output_len: 451\n",
- " start_token: <s>\n",
- " end_token: <e>\n",
- " pad_token: <p>\n",
"network:\n",
- " encoder:\n",
- " _target_: text_recognizer.networks.encoders.efficientnet.EfficientNet\n",
- " arch: b0\n",
- " out_channels: 1280\n",
- " stochastic_dropout_rate: 0.2\n",
- " bn_momentum: 0.99\n",
- " bn_eps: 0.001\n",
- " decoder:\n",
- " _target_: text_recognizer.networks.transformer.Decoder\n",
- " dim: 96\n",
- " depth: 2\n",
- " num_heads: 8\n",
- " attn_fn: text_recognizer.networks.transformer.attention.Attention\n",
- " attn_kwargs:\n",
- " dim_head: 16\n",
- " dropout_rate: 0.2\n",
- " norm_fn: torch.nn.LayerNorm\n",
- " ff_fn: text_recognizer.networks.transformer.mlp.FeedForward\n",
- " ff_kwargs:\n",
- " dim_out: null\n",
- " expansion_factor: 4\n",
- " glu: true\n",
- " dropout_rate: 0.2\n",
- " cross_attend: true\n",
- " pre_norm: true\n",
- " rotary_emb: null\n",
- " _target_: text_recognizer.networks.conv_transformer.ConvTransformer\n",
- " input_dims:\n",
- " - 1\n",
- " - 576\n",
- " - 640\n",
- " hidden_dim: 96\n",
- " dropout_rate: 0.2\n",
- " num_classes: 1006\n",
- " pad_index: 1002\n",
+ " _target_: text_recognizer.networks.vqvae.VQVAE\n",
+ " in_channels: 1\n",
+ " res_channels: 32\n",
+ " num_residual_layers: 2\n",
+ " embedding_dim: 64\n",
+ " num_embeddings: 512\n",
+ " decay: 0.99\n",
+ " activation: mish\n",
"optimizer:\n",
" _target_: madgrad.MADGRAD\n",
- " lr: 0.001\n",
+ " lr: 0.01\n",
" momentum: 0.9\n",
" weight_decay: 0\n",
" eps: 1.0e-06\n",
@@ -257,7 +157,7 @@
" fast_dev_run: false\n",
" gpus: 1\n",
" precision: 16\n",
- " max_epochs: 512\n",
+ " max_epochs: 64\n",
" terminate_on_nan: true\n",
" weights_summary: top\n",
" limit_train_batches: 1.0\n",
@@ -269,91 +169,181 @@
"train: true\n",
"test: true\n",
"logging: INFO\n",
+ "work_dir: ${hydra:runtime.cwd}\n",
"debug: false\n",
+ "print_config: true\n",
+ "ignore_warnings: true\n",
"\n",
- "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}, 'criterion': {'_target_': 'text_recognizer.criterions.label_smoothing.LabelSmoothingLoss', 'smoothing': 0.1, 'ignore_index': 1002}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 8, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 512, 'steps_per_epoch': 4992, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'interval': 'step', 'monitor': 'val/loss', 'max_output_len': 451, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}, 'network': {'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 96, 'depth': 2, 'num_heads': 8, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'dim_head': 16, 'dropout_rate': 0.2}, 'norm_fn': 'torch.nn.LayerNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'cross_attend': True, 'pre_norm': True, 'rotary_emb': None}, '_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 576, 640], 'hidden_dim': 96, 'dropout_rate': 0.2, 'num_classes': 1006, 'pad_index': 1002}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.001, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 
'precision': 16, 'max_epochs': 512, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'debug': False}\n"
+ "{'callbacks': {'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': '{epoch:02d}'}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_image_reconstruction': {'_target_': 'callbacks.wandb_callbacks.LogReconstuctedImages', 'num_samples': 8}}, 'criterion': {'_target_': 'torch.nn.MSELoss', 'reduction': 'mean'}, 'datamodule': {'_target_': 'text_recognizer.data.iam_extended_paragraphs.IAMExtendedParagraphs', 'batch_size': 32, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': False, 'word_pieces': True}, 'logger': {'wandb': {'_target_': 'pytorch_lightning.loggers.wandb.WandbLogger', 'project': 'text-recognizer', 'name': None, 'save_dir': '.', 'offline': False, 'id': None, 'log_model': False, 'prefix': '', 'job_type': 'train', 'group': '', 'tags': []}}, 'lr_scheduler': {'_target_': 'torch.optim.lr_scheduler.OneCycleLR', 'max_lr': 0.001, 'total_steps': None, 'epochs': 64, 'steps_per_epoch': 624, 'pct_start': 0.3, 'anneal_strategy': 'cos', 'cycle_momentum': True, 'base_momentum': 0.85, 'max_momentum': 0.95, 'div_factor': 25.0, 'final_div_factor': 10000.0, 'three_phase': True, 'last_epoch': -1, 'verbose': False}, 'mapping': {'_target_': 'text_recognizer.data.word_piece_mapping.WordPieceMapping', 'num_features': 1000, 'tokens': 'iamdb_1kwp_tokens_1000.txt', 'lexicon': 'iamdb_1kwp_lex_1000.txt', 'data_dir': None, 'use_words': False, 'prepend_wordsep': False, 'special_tokens': ['<s>', '<e>', '<p>'], 'extra_symbols': ['\\n']}, 'model': {'_target_': 'text_recognizer.models.vqvae.VQVAELitModel', 'interval': 'step', 'monitor': 'val/loss'}, 'network': {'_target_': 'text_recognizer.networks.vqvae.VQVAE', 'in_channels': 1, 'res_channels': 32, 'num_residual_layers': 2, 'embedding_dim': 64, 'num_embeddings': 512, 'decay': 0.99, 'activation': 'mish'}, 'optimizer': {'_target_': 'madgrad.MADGRAD', 'lr': 0.01, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': False, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 64, 'terminate_on_nan': True, 'weights_summary': 'top', 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None}, 'seed': 4711, 'tune': False, 'train': True, 'test': True, 'logging': 'INFO', 'work_dir': '${hydra:runtime.cwd}', 'debug': False, 'print_config': True, 'ignore_warnings': True}\n"
]
}
],
"source": [
"# context initialization\n",
"with initialize(config_path=\"../training/conf/\", job_name=\"test_app\"):\n",
- " cfg = compose(config_name=\"config\")\n",
+ " cfg = compose(config_name=\"config\", overrides=[\"+experiment=vqvae\"])\n",
" print(OmegaConf.to_yaml(cfg))\n",
" print(cfg)"
]
},
{
"cell_type": "code",
- "execution_count": 10,
- "id": "9382f0ab-8760-4d59-b0b5-b8b65dd1ea31",
+ "execution_count": 4,
+ "id": "c1a9aa6b-6405-4ffe-b065-02340762476a",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-08-04 04:49:04.188 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
+ ]
+ }
+ ],
+ "source": [
+ "mapping = instantiate(cfg.mapping)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "network = instantiate(cfg.network)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "6147cd3e-0ad1-490f-917d-21be9bb8ce1c",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x = torch.rand(1, 1, 576, 640)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "a0ecea0c-abaf-4d5d-a13d-c085c1e4d282",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "{'model_checkpoint': {'_target_': 'pytorch_lightning.callbacks.ModelCheckpoint', 'monitor': 'val/loss', 'save_top_k': 1, 'save_last': True, 'mode': 'min', 'verbose': False, 'dirpath': 'checkpoints/', 'filename': {'epoch:02d': None}}, 'learning_rate_monitor': {'_target_': 'pytorch_lightning.callbacks.LearningRateMonitor', 'logging_interval': 'step', 'log_momentum': False}, 'watch_model': {'_target_': 'callbacks.wandb_callbacks.WatchModel', 'log': 'all', 'log_freq': 100}, 'upload_code_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCodeAsArtifact', 'project_dir': '${work_dir}/text_recognizer'}, 'upload_ckpts_as_artifact': {'_target_': 'callbacks.wandb_callbacks.UploadCheckpointsAsArtifact', 'ckpt_dir': 'checkpoints/', 'upload_best_only': True}, 'log_text_predictions': {'_target_': 'callbacks.wandb_callbacks.LogTextPredictions', 'num_samples': 8}}"
+ "torch.Size([1, 64, 144, 160])"
]
},
- "execution_count": 10,
+ "execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "cfg.get(\"callbacks\")"
+ "network.encode(x)[0].shape"
]
},
{
"cell_type": "code",
- "execution_count": 12,
- "id": "216d5680-66bf-4190-9401-1a59dbbc43af",
+ "execution_count": 38,
+ "id": "a7b9f249-7e5e-4f31-bbe1-cfd6d3701cf0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
- "pytorch_lightning.callbacks.ModelCheckpoint\n",
- "pytorch_lightning.callbacks.LearningRateMonitor\n",
- "callbacks.wandb_callbacks.WatchModel\n",
- "callbacks.wandb_callbacks.UploadCodeAsArtifact\n",
- "callbacks.wandb_callbacks.UploadCheckpointsAsArtifact\n",
- "callbacks.wandb_callbacks.LogTextPredictions\n"
+ "torch.Size([512])\n",
+ "torch.Size([512])\n",
+ "torch.Size([512])\n",
+ "torch.Size([512])\n"
]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([1, 1, 576, 640])"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
}
],
"source": [
- "for l in cfg.callbacks.values():\n",
- " print(l.get(\"_target_\"))"
+ "network(x)[0].shape"
]
},
{
"cell_type": "code",
- "execution_count": 4,
- "id": "c1a9aa6b-6405-4ffe-b065-02340762476a",
+ "execution_count": null,
+ "id": "23c9d90c-042b-423e-ab85-18449e29ded4",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2021-08-03 15:27:02.069 | DEBUG | text_recognizer.data.word_piece_mapping:__init__:37 - Using data dir: /home/aktersnurra/projects/text-recognizer/data/downloaded/iam/iamdb\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
- "mapping = instantiate(cfg.mapping)"
+ "576 / 4"
]
},
{
"cell_type": "code",
- "execution_count": 5,
- "id": "969ba3be-d78f-4b1e-b522-ea8a42669e86",
+ "execution_count": null,
+ "id": "047ebc09-1c74-44a7-a314-1099f09722fe",
"metadata": {},
"outputs": [],
"source": [
- "network = instantiate(cfg.network)"
+ "t = torch.randint(0, 1006, (1, 451)).cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "87372dde-2b1a-432b-ab79-0b116124c724",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "z = torch.rand((1, 36 * 40, 128)).cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "cf7ca9bf-cafa-4128-9db7-046c16933a52",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "network = network.cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "dfceaa5f-9ad8-4d33-addb-c56e8da48356",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "network.decode(z, t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9105fbbb-4363-4d3e-a01e-bc519c3b9c3a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decoder = decoder.cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "c5797ec4-7a6a-46fd-8adc-265df44d0341",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "decoder(z, t).shape"
]
},
{
@@ -368,11 +358,9 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "a6fae1fa-492d-4648-80fd-1c0dac659b02",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"datamodule = instantiate(cfg.datamodule, mapping=mapping)"
@@ -380,19 +368,10 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "514053ef-fcac-4f3c-a7c8-72c6927d6798",
"metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "2021-08-03 15:28:22.541 | INFO | text_recognizer.data.iam_paragraphs:setup:95 - Loading IAM paragraph regions and lines for None...\n",
- "2021-08-03 15:28:45.280 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:68 - IAM Synthetic dataset steup for stage None...\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"datamodule.prepare_data()\n",
"datamodule.setup()"
@@ -400,21 +379,10 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "4bad950b-a197-4c60-ad89-903124659a98",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "4992"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"len(datamodule.train_dataloader())"
]
@@ -431,7 +399,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "f6e01c15-9a1b-4036-87ae-78716c592264",
"metadata": {},
"outputs": [],
@@ -441,7 +409,7 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "4dc475fc-31f4-487e-88c8-b0f445131f5b",
"metadata": {},
"outputs": [],
@@ -451,7 +419,7 @@
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "c5c8ed64-d98c-47b5-baf2-1ba57a6c882f",
"metadata": {},
"outputs": [],
@@ -461,11 +429,9 @@
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"id": "b5ff5b24-f804-402b-a8ab-f366443025ca",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
" model = hydra.utils.instantiate(\n",
@@ -481,21 +447,10 @@
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": null,
"id": "99f8a39f-8b10-4f7d-8bff-52794fd48717",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<bound method WordPieceMapping.get_index of <text_recognizer.data.word_piece_mapping.WordPieceMapping object at 0x7fae3b489610>>"
- ]
- },
- "execution_count": 11,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"mapping.get_index"
]
@@ -514,9 +469,7 @@
"cell_type": "code",
"execution_count": null,
"id": "8f0742ad-5e2f-42d5-83e7-6e46398b4f0f",
- "metadata": {
- "tags": []
- },
+ "metadata": {},
"outputs": [],
"source": [
"net"
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 22da018..5890fd9 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -14,31 +14,33 @@ from text_recognizer.models.base import BaseLitModel
class VQVAELitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
+ latent_loss_weight: float = attr.ib(default=0.25)
+
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
- return self.network.predict(data)
+ return self.network(data)
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, _ = batch
- reconstructions, vq_loss = self.network(data)
+ reconstructions, vq_loss = self(data)
loss = self.loss_fn(reconstructions, data)
- loss += vq_loss
+ loss += self.latent_loss_weight * vq_loss
self.log("train/loss", loss)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Validation step."""
data, _ = batch
- reconstructions, vq_loss = self.network(data)
+ reconstructions, vq_loss = self(data)
loss = self.loss_fn(reconstructions, data)
- loss += vq_loss
+ loss += self.latent_loss_weight * vq_loss
self.log("val/loss", loss, prog_bar=True)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
data, _ = batch
- reconstructions, vq_loss = self.network(data)
+ reconstructions, vq_loss = self(data)
loss = self.loss_fn(reconstructions, data)
- loss += vq_loss
+ loss += self.latent_loss_weight * vq_loss
self.log("test/loss", loss)
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
new file mode 100644
index 0000000..a972565
--- /dev/null
+++ b/text_recognizer/networks/vq_transformer.py
@@ -0,0 +1,77 @@
+"""Vector quantized encoder, transformer decoder."""
+import math
+from typing import Tuple
+
+from torch import nn, Tensor
+
+from text_recognizer.networks.encoders.efficientnet import EfficientNet
+from text_recognizer.networks.conv_transformer import ConvTransformer
+from text_recognizer.networks.transformer.layers import Decoder
+from text_recognizer.networks.transformer.positional_encodings import (
+ PositionalEncoding,
+ PositionalEncoding2D,
+)
+
+
+class VqTransformer(ConvTransformer):
+ """Convolutional encoder and transformer decoder network."""
+
+ def __init__(
+ self,
+ input_dims: Tuple[int, int, int],
+ hidden_dim: int,
+ dropout_rate: float,
+ num_classes: int,
+ pad_index: Tensor,
+ encoder: EfficientNet,
+ decoder: Decoder,
+ ) -> None:
+ # TODO: Load pretrained vqvae encoder.
+ super().__init__(
+ input_dims=input_dims,
+ hidden_dim=hidden_dim,
+ dropout_rate=dropout_rate,
+ num_classes=num_classes,
+ pad_index=pad_index,
+ encoder=encoder,
+ decoder=decoder,
+ )
+        # Latent projector for downsampling the number of filters and adding
+        # a 2D positional encoding.
+ self.latent_encoder = nn.Sequential(
+ nn.Conv2d(
+ in_channels=self.encoder.out_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=1,
+ ),
+ PositionalEncoding2D(
+ hidden_dim=self.hidden_dim,
+ max_h=self.input_dims[1],
+ max_w=self.input_dims[2],
+ ),
+ nn.Flatten(start_dim=2),
+ )
+
+ def encode(self, x: Tensor) -> Tensor:
+ """Encodes an image into a latent feature vector.
+
+ Args:
+ x (Tensor): Image tensor.
+
+ Shape:
+ - x: :math: `(B, C, H, W)`
+ - z: :math: `(B, Sx, E)`
+
+        where Sx is the length of the flattened feature maps projected from
+        the encoder and E is the latent dimension of each position in the
+        projected feature maps.
+
+        Returns:
+            Tensor: A latent embedding of the image.
+ """
+ z = self.encoder(x)
+ z = self.latent_encoder(z)
+
+ # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
+ z = z.permute(0, 2, 1)
+ return z
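
The latent_encoder above squeezes the encoder channels down to hidden_dim with a 1 x 1 convolution, adds a 2D positional encoding, and flattens the grid before the permute in encode. A shape-only sketch with nn.Identity standing in for the project's PositionalEncoding2D; the 1280-channel, 18 x 20 input is an assumed EfficientNet-b0 output used purely for illustration:

import torch
from torch import nn

latent_encoder = nn.Sequential(
    nn.Conv2d(in_channels=1280, out_channels=128, kernel_size=1),
    nn.Identity(),  # placeholder for PositionalEncoding2D
    nn.Flatten(start_dim=2),
)
z = torch.rand(1, 1280, 18, 20)  # assumed encoder output: [B, C, Ho, Wo]
z = latent_encoder(z)            # [1, 128, 360]
z = z.permute(0, 2, 1)           # [1, 360, 128] == [B, Sx, E]
print(z.shape)
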
diff --git a/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
index 763953c..7d56bdb 100644
--- a/text_recognizer/networks/vqvae/__init__.py
+++ b/text_recognizer/networks/vqvae/__init__.py
@@ -1,5 +1,2 @@
"""VQ-VAE module."""
-from .decoder import Decoder
-from .encoder import Encoder
-from .vector_quantizer import VectorQuantizer
from .vqvae import VQVAE
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 32de912..3f59f0d 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -1,133 +1,65 @@
"""CNN decoder for the VQ-VAE."""
-
-from typing import List, Optional, Tuple, Type
-
-import torch
+import attr
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+from text_recognizer.networks.vqvae.residual import Residual
+@attr.s(eq=False)
class Decoder(nn.Module):
"""A CNN encoder network."""
- def __init__(
- self,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- embedding_dim: int,
- upsampling: Optional[List[List[int]]] = None,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- if dropout_rate:
- if activation == "selu":
- dropout = nn.AlphaDropout(p=dropout_rate)
- else:
- dropout = nn.Dropout(p=dropout_rate)
- else:
- dropout = None
-
- self.upsampling = upsampling
-
- self.res_block = nn.ModuleList([])
- self.upsampling_block = nn.ModuleList([])
-
- self.embedding_dim = embedding_dim
- activation = activation_function(activation)
-
- # Configure encoder.
- self.decoder = self._build_decoder(
- channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
- )
-
- def _build_decompression_block(
- self,
- in_channels: int,
- channels: int,
- kernel_sizes: List[int],
- strides: List[int],
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.ModuleList:
- modules = nn.ModuleList([])
- configuration = zip(channels, kernel_sizes, strides)
- for i, (out_channels, kernel_size, stride) in enumerate(configuration):
- modules.append(
- nn.Sequential(
- nn.ConvTranspose2d(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=1,
- ),
- activation,
- )
- )
-
- if self.upsampling and i < len(self.upsampling):
- modules.append(nn.Upsample(size=self.upsampling[i]),)
+ in_channels: int = attr.ib()
+ embedding_dim: int = attr.ib()
+ out_channels: int = attr.ib()
+ res_channels: int = attr.ib()
+ num_residual_layers: int = attr.ib()
+ activation: str = attr.ib()
+ decoder: nn.Sequential = attr.ib(init=False)
- if dropout is not None:
- modules.append(dropout)
-
- in_channels = out_channels
-
- modules.extend(
- nn.Sequential(
- nn.ConvTranspose2d(
- in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
- ),
- nn.Tanh(),
- )
- )
-
- return modules
-
- def _build_decoder(
- self,
- channels: int,
- kernel_sizes: List[int],
- strides: List[int],
- num_residual_layers: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.Sequential:
-
- self.res_block.append(
- nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
- )
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
- # Bottleneck module.
- self.res_block.extend(
- nn.ModuleList(
- [
- _ResidualBlock(channels[0], channels[0], dropout)
- for i in range(num_residual_layers)
- ]
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.decoder = self._build_decompression_block()
+
+ def _build_decompression_block(self,) -> nn.Sequential:
+ activation_fn = activation_function(self.activation)
+ blocks = [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.embedding_dim,
+ kernel_size=3,
+ padding=1,
)
- )
-
- # Decompression module
- self.upsampling_block.extend(
- self._build_decompression_block(
- channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+ ]
+ for _ in range(self.num_residual_layers):
+ blocks.append(
+ Residual(in_channels=self.embedding_dim, out_channels=self.res_channels)
)
- )
-
- self.res_block = nn.Sequential(*self.res_block)
- self.upsampling_block = nn.Sequential(*self.upsampling_block)
-
- return nn.Sequential(self.res_block, self.upsampling_block)
+ blocks.append(activation_fn)
+ blocks += [
+ nn.ConvTranspose2d(
+ in_channels=self.embedding_dim,
+ out_channels=self.embedding_dim // 2,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ nn.ConvTranspose2d(
+ in_channels=self.embedding_dim // 2,
+ out_channels=self.out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*blocks)
def forward(self, z_q: Tensor) -> Tensor:
"""Reconstruct input from given codes."""
- x_reconstruction = self.decoder(z_q)
- return x_reconstruction
+ return self.decoder(z_q)
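
Each ConvTranspose2d(kernel_size=4, stride=2, padding=1) in the rebuilt decoder doubles the spatial size (H_out = (H_in - 1) * 2 - 2 + 4 = 2 * H_in), so the two of them undo the encoder's 4x downsampling. A one-layer check using the latent grid from the notebook above:

import torch
from torch import nn

up = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=4, stride=2, padding=1)
z = torch.rand(1, 64, 144, 160)
print(up(z).shape)  # torch.Size([1, 32, 288, 320])
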
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index 65801df..e480545 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -1,147 +1,75 @@
"""CNN encoder for the VQ-VAE."""
from typing import Sequence, Optional, Tuple, Type
-import torch
+import attr
from torch import nn
from torch import Tensor
from text_recognizer.networks.util import activation_function
-from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
-
-
-class _ResidualBlock(nn.Module):
- def __init__(
- self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
- ) -> None:
- super().__init__()
- self.block = [
- nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- nn.ReLU(inplace=True),
- nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
- ]
-
- if dropout is not None:
- self.block.append(dropout)
-
- self.block = nn.Sequential(*self.block)
-
- def forward(self, x: Tensor) -> Tensor:
- """Apply the residual forward pass."""
- return x + self.block(x)
+from text_recognizer.networks.vqvae.residual import Residual
+@attr.s(eq=False)
class Encoder(nn.Module):
"""A CNN encoder network."""
- def __init__(
- self,
- in_channels: int,
- channels: Sequence[int],
- kernel_sizes: Sequence[int],
- strides: Sequence[int],
- num_residual_layers: int,
- embedding_dim: int,
- num_embeddings: int,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- ) -> None:
- super().__init__()
-
- if dropout_rate:
- if activation == "selu":
- dropout = nn.AlphaDropout(p=dropout_rate)
- else:
- dropout = nn.Dropout(p=dropout_rate)
- else:
- dropout = None
-
- self.embedding_dim = embedding_dim
- self.num_embeddings = num_embeddings
- self.beta = beta
- activation = activation_function(activation)
-
- # Configure encoder.
- self.encoder = self._build_encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- activation,
- dropout,
- )
+ in_channels: int = attr.ib()
+ out_channels: int = attr.ib()
+ res_channels: int = attr.ib()
+ num_residual_layers: int = attr.ib()
+ embedding_dim: int = attr.ib()
+ activation: str = attr.ib()
+ encoder: nn.Sequential = attr.ib(init=False)
- # Configure Vector Quantizer.
- self.vector_quantizer = VectorQuantizer(
- self.num_embeddings, self.embedding_dim, self.beta
- )
-
- @staticmethod
- def _build_compression_block(
- in_channels: int,
- channels: int,
- kernel_sizes: Sequence[int],
- strides: Sequence[int],
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.ModuleList:
- modules = nn.ModuleList([])
- configuration = zip(channels, kernel_sizes, strides)
- for out_channels, kernel_size, stride in configuration:
- modules.append(
- nn.Sequential(
- nn.Conv2d(
- in_channels, out_channels, kernel_size, stride=stride, padding=1
- ),
- activation,
- )
- )
-
- if dropout is not None:
- modules.append(dropout)
-
- in_channels = out_channels
-
- return modules
+ def __attrs_pre_init__(self) -> None:
+ super().__init__()
- def _build_encoder(
- self,
- in_channels: int,
- channels: int,
- kernel_sizes: Sequence[int],
- strides: Sequence[int],
- num_residual_layers: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
- ) -> nn.Sequential:
- encoder = nn.ModuleList([])
+ def __attrs_post_init__(self) -> None:
+ """Post init configuration."""
+ self.encoder = self._build_compression_block()
+
+ def _build_compression_block(self) -> nn.Sequential:
+ activation_fn = activation_function(self.activation)
+ block = [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.out_channels // 2,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=self.out_channels // 2,
+ out_channels=self.out_channels,
+ kernel_size=4,
+ stride=2,
+ padding=1,
+ ),
+ activation_fn,
+ nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ padding=1,
+ ),
+ ]
- # compression module
- encoder.extend(
- self._build_compression_block(
- in_channels, channels, kernel_sizes, strides, activation, dropout
+ for _ in range(self.num_residual_layers):
+ block.append(
+ Residual(in_channels=self.out_channels, out_channels=self.res_channels)
)
- )
- # Bottleneck module.
- encoder.extend(
- nn.ModuleList(
- [
- _ResidualBlock(channels[-1], channels[-1], dropout)
- for i in range(num_residual_layers)
- ]
+ block.append(
+ nn.Conv2d(
+ in_channels=self.out_channels,
+ out_channels=self.embedding_dim,
+ kernel_size=1,
)
)
- encoder.append(
- nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
- )
-
- return nn.Sequential(*encoder)
+ return nn.Sequential(*block)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes input into a discrete representation."""
- z_e = self.encoder(x)
- z_q, vq_loss = self.vector_quantizer(z_e)
- return z_q, vq_loss
+ return self.encoder(x)
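
The rebuilt encoder uses two Conv2d(kernel_size=4, stride=2, padding=1) layers, each of which halves even spatial sizes (H_out = (H_in + 2 - 4) // 2 + 1 = H_in // 2), so a 576 x 640 page becomes a 144 x 160 latent grid, matching the notebook output above. A one-layer check:

import torch
from torch import nn

down = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=4, stride=2, padding=1)
x = torch.rand(1, 1, 576, 640)
print(down(x).shape)  # torch.Size([1, 32, 288, 320])
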
diff --git a/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index f92c7ee..5e0b602 100644
--- a/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -2,9 +2,7 @@
Reference:
https://github.com/AntixK/PyTorch-VAE/blob/master/models/vq_vae.py
-
"""
-
from einops import rearrange
import torch
from torch import nn
@@ -12,21 +10,27 @@ from torch import Tensor
from torch.nn import functional as F
+class EmbeddingEMA(nn.Module):
+ def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
+ super().__init__()
+ weight = torch.zeros(num_embeddings, embedding_dim)
+ nn.init.kaiming_uniform_(weight, nonlinearity="linear")
+ self.register_buffer("weight", weight)
+ self.register_buffer("_cluster_size", torch.zeros(num_embeddings))
+ self.register_buffer("_weight_avg", weight)
+
+
class VectorQuantizer(nn.Module):
"""The codebook that contains quantized vectors."""
def __init__(
- self, num_embeddings: int, embedding_dim: int, beta: float = 0.25
+ self, num_embeddings: int, embedding_dim: int, decay: float = 0.99
) -> None:
super().__init__()
- self.K = num_embeddings
- self.D = embedding_dim
- self.beta = beta
-
- self.embedding = nn.Embedding(self.K, self.D)
-
- # Initialize the codebook.
- nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
+ self.num_embeddings = num_embeddings
+ self.embedding_dim = embedding_dim
+ self.decay = decay
+ self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim)
def discretization_bottleneck(self, latent: Tensor) -> Tensor:
"""Computes the code nearest to the latent representation.
@@ -62,7 +66,7 @@ class VectorQuantizer(nn.Module):
# Convert to one-hot encodings, aka discrete bottleneck.
one_hot_encoding = torch.zeros(
- encoding_indices.shape[0], self.K, device=latent.device
+ encoding_indices.shape[0], self.num_embeddings, device=latent.device
)
one_hot_encoding.scatter_(1, encoding_indices, 1) # [BHW x K]
@@ -71,9 +75,27 @@ class VectorQuantizer(nn.Module):
quantized_latent = rearrange(
quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
)
+ if self.training:
+ self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)
return quantized_latent
+ def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+ batch_cluster_size = one_hot_encoding.sum(axis=0)
+ batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
+ self.embedding._cluster_size.data.mul_(self.decay).add_(
+ batch_cluster_size, alpha=1 - self.decay
+ )
+ self.embedding._weight_avg.data.mul_(self.decay).add_(
+ batch_embedding_avg, alpha=1 - self.decay
+ )
+ new_embedding = self.embedding._weight_avg / (
+ self.embedding._cluster_size + 1.0e-5
+ ).unsqueeze(1)
+ self.embedding.weight.data.copy_(new_embedding)
+
def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
"""Vector Quantization loss.
@@ -96,9 +118,10 @@ class VectorQuantizer(nn.Module):
            Tensor: The combined VQ loss.
"""
- embedding_loss = F.mse_loss(quantized_latent, latent.detach())
commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
- return embedding_loss + self.beta * commitment_loss
+ # embedding_loss = F.mse_loss(quantized_latent, latent.detach())
+ # return embedding_loss + self.beta * commitment_loss
+ return commitment_loss
def forward(self, latent: Tensor) -> Tensor:
"""Forward pass that returns the quantized vector and the vq loss."""
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
new file mode 100644
index 0000000..98109b8
--- /dev/null
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -0,0 +1,18 @@
+"""Residual block."""
+from torch import nn
+from torch import Tensor
+
+
+class Residual(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int,) -> None:
+ super().__init__()
+ self.block = nn.Sequential(
+ nn.Mish(inplace=True),
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
+ nn.Mish(inplace=True),
+ nn.Conv2d(out_channels, in_channels, kernel_size=1, bias=False),
+ )
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Apply the residual forward pass."""
+ return x + self.block(x)
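
The new Residual block is pre-activation: Mish, a 3 x 3 convolution down to res_channels, Mish, then a 1 x 1 convolution back to in_channels, added to the input, so it never changes the tensor shape. A small usage check, assuming the package is importable and PyTorch >= 1.9 for nn.Mish:

import torch
from text_recognizer.networks.vqvae.residual import Residual

res = Residual(in_channels=64, out_channels=32)
x = torch.rand(1, 64, 144, 160)
assert res(x).shape == x.shape  # the skip connection keeps the shape intact
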
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 5aa929b..1585d40 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -1,10 +1,14 @@
"""The VQ-VAE."""
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Tuple
+import torch
from torch import nn
from torch import Tensor
+import torch.nn.functional as F
-from text_recognizer.networks.vqvae import Decoder, Encoder
+from text_recognizer.networks.vqvae.decoder import Decoder
+from text_recognizer.networks.vqvae.encoder import Encoder
+from text_recognizer.networks.vqvae.quantizer import VectorQuantizer
class VQVAE(nn.Module):
@@ -13,62 +17,92 @@ class VQVAE(nn.Module):
def __init__(
self,
in_channels: int,
- channels: List[int],
- kernel_sizes: List[int],
- strides: List[int],
+ res_channels: int,
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
- upsampling: Optional[List[List[int]]] = None,
- beta: float = 0.25,
- activation: str = "leaky_relu",
- dropout_rate: float = 0.0,
- *args: Any,
- **kwargs: Dict,
+ decay: float = 0.99,
+ activation: str = "mish",
) -> None:
super().__init__()
+ # Encoders
+ self.btm_encoder = Encoder(
+ in_channels=1,
+ out_channels=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ embedding_dim=embedding_dim,
+ activation=activation,
+ )
+
+ self.top_encoder = Encoder(
+ in_channels=embedding_dim,
+ out_channels=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ embedding_dim=embedding_dim,
+ activation=activation,
+ )
+
+ # Quantizers
+ self.btm_quantizer = VectorQuantizer(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
+ )
- # configure encoder.
- self.encoder = Encoder(
- in_channels,
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- num_embeddings,
- beta,
- activation,
- dropout_rate,
+ self.top_quantizer = VectorQuantizer(
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
)
- # Configure decoder.
- channels.reverse()
- kernel_sizes.reverse()
- strides.reverse()
- self.decoder = Decoder(
- channels,
- kernel_sizes,
- strides,
- num_residual_layers,
- embedding_dim,
- upsampling,
- activation,
- dropout_rate,
+ # Decoders
+ self.top_decoder = Decoder(
+ in_channels=embedding_dim,
+ out_channels=embedding_dim,
+ embedding_dim=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ activation=activation,
+ )
+
+ self.btm_decoder = Decoder(
+ in_channels=2 * embedding_dim,
+ out_channels=in_channels,
+ embedding_dim=embedding_dim,
+ res_channels=res_channels,
+ num_residual_layers=num_residual_layers,
+ activation=activation,
)
def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes input to a latent code."""
- return self.encoder(x)
+ z_btm = self.btm_encoder(x)
+ z_top = self.top_encoder(z_btm)
+ return z_btm, z_top
+
+ def quantize(
+ self, z_btm: Tensor, z_top: Tensor
+ ) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+        q_btm, vq_btm_loss = self.btm_quantizer(z_btm)
+ q_top, vq_top_loss = self.top_quantizer(z_top)
+ return q_btm, vq_btm_loss, q_top, vq_top_loss
- def decode(self, z_q: Tensor) -> Tensor:
+ def decode(self, q_btm: Tensor, q_top: Tensor) -> Tuple[Tensor, Tensor]:
"""Reconstructs input from latent codes."""
- return self.decoder(z_q)
+ d_top = self.top_decoder(q_top)
+ x_hat = self.btm_decoder(torch.cat((d_top, q_btm), dim=1))
+ return d_top, x_hat
+
+ def loss_fn(
+ self, vq_btm_loss: Tensor, vq_top_loss: Tensor, d_top: Tensor, z_btm: Tensor
+ ) -> Tensor:
+ """Calculates the latent loss."""
+ return 0.5 * (vq_top_loss + vq_btm_loss) + F.mse_loss(d_top, z_btm)
def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Compresses and decompresses input."""
- if len(x.shape) < 4:
- x = x[(None,) * (4 - len(x.shape))]
- z_q, vq_loss = self.encode(x)
- x_reconstruction = self.decode(z_q)
- return x_reconstruction, vq_loss
+ z_btm, z_top = self.encode(x)
+ q_btm, vq_btm_loss, q_top, vq_top_loss = self.quantize(z_btm=z_btm, z_top=z_top)
+ d_top, x_hat = self.decode(q_btm=q_btm, q_top=q_top)
+ vq_loss = self.loss_fn(
+ vq_btm_loss=vq_btm_loss, vq_top_loss=vq_top_loss, d_top=d_top, z_btm=z_btm
+ )
+ return x_hat, vq_loss
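
The rewritten VQVAE is a two-level hierarchy in the style of VQ-VAE-2: the bottom encoder downsamples the image 4x, the top encoder downsamples that latent another 4x, both levels are quantized, the decoded top latent is concatenated with the bottom codes (hence the 2 * embedding_dim decoder input), and the latent loss averages the two commitment losses plus an MSE tying the decoded top latent to the bottom encoding. A shape sketch with the values from training/conf/network/vqvae.yaml below, assuming the package is importable; the shapes match the notebook run above:

import torch
from text_recognizer.networks.vqvae import VQVAE

vqvae = VQVAE(
    in_channels=1,
    res_channels=32,
    num_residual_layers=2,
    embedding_dim=64,
    num_embeddings=512,
    decay=0.99,
    activation="mish",
)
x = torch.rand(1, 1, 576, 640)
z_btm, z_top = vqvae.encode(x)  # [1, 64, 144, 160] and [1, 64, 36, 40]
x_hat, vq_loss = vqvae(x)       # x_hat: [1, 1, 576, 640]
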
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index 906531f..c750e4b 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -5,6 +5,7 @@ import wandb
from pytorch_lightning import Callback, LightningModule, Trainer
from pytorch_lightning.loggers import LoggerCollection, WandbLogger
from pytorch_lightning.utilities import rank_zero_only
+from torch.utils.data import DataLoader
def get_wandb_logger(trainer: Trainer) -> WandbLogger:
@@ -86,7 +87,11 @@ class LogTextPredictions(Callback):
self.ready = False
def _log_predictions(
- self, stage: str, trainer: Trainer, pl_module: LightningModule
+ self,
+ stage: str,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ dataloader: DataLoader,
) -> None:
"""Logs the predicted text contained in the images."""
if not self.ready:
@@ -96,22 +101,20 @@ class LogTextPredictions(Callback):
experiment = logger.experiment
# Get a validation batch from the validation dataloader.
- samples = next(iter(trainer.datamodule.val_dataloader()))
+ samples = next(iter(dataloader))
imgs, labels = samples
imgs = imgs.to(device=pl_module.device)
logits = pl_module(imgs)
mapping = pl_module.mapping
- columns = ["id", "image", "prediction", "truth"]
+ columns = ["image", "prediction", "truth"]
data = [
- [id, wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
- for id, (img, pred, label) in enumerate(
- zip(
- imgs[: self.num_samples],
- logits[: self.num_samples],
- labels[: self.num_samples],
- )
+ [wandb.Image(img), mapping.get_text(pred), mapping.get_text(label)]
+ for img, pred, label in zip(
+ imgs[: self.num_samples],
+ logits[: self.num_samples],
+ labels[: self.num_samples],
)
]
@@ -133,11 +136,17 @@ class LogTextPredictions(Callback):
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs predictions on validation epoch end."""
- self._log_predictions(stage="val", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.val_dataloader()
+ self._log_predictions(
+ stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
- self._log_predictions(stage="test", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.test_dataloader()
+ self._log_predictions(
+ stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
class LogReconstuctedImages(Callback):
@@ -148,7 +157,11 @@ class LogReconstuctedImages(Callback):
self.ready = False
def _log_reconstruction(
- self, stage: str, trainer: Trainer, pl_module: LightningModule
+ self,
+ stage: str,
+ trainer: Trainer,
+ pl_module: LightningModule,
+ dataloader: DataLoader,
) -> None:
"""Logs the reconstructions."""
if not self.ready:
@@ -158,20 +171,24 @@ class LogReconstuctedImages(Callback):
experiment = logger.experiment
# Get a validation batch from the validation dataloader.
- samples = next(iter(trainer.datamodule.val_dataloader()))
+ samples = next(iter(dataloader))
imgs, _ = samples
+            columns = ["input", "reconstruction"]
imgs = imgs.to(device=pl_module.device)
- reconstructions = pl_module(imgs)
+ reconstructions = pl_module(imgs)[0]
+ data = [
+ [wandb.Image(img), wandb.Image(rec)]
+ for img, rec in zip(
+ imgs[: self.num_samples], reconstructions[: self.num_samples]
+ )
+ ]
experiment.log(
{
- f"Reconstructions/{experiment.name}/{stage}": [
- [wandb.Image(img), wandb.Image(rec),]
- for img, rec in zip(
- imgs[: self.num_samples], reconstructions[: self.num_samples],
- )
- ]
+ f"Reconstructions/{experiment.name}/{stage}": wandb.Table(
+                    data=data, columns=columns
+ )
}
)
@@ -189,8 +206,14 @@ class LogReconstuctedImages(Callback):
self, trainer: Trainer, pl_module: LightningModule
) -> None:
"""Logs predictions on validation epoch end."""
- self._log_reconstruction(stage="val", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.val_dataloader()
+ self._log_reconstruction(
+ stage="val", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
"""Logs predictions on train epoch end."""
- self._log_reconstruction(stage="test", trainer=trainer, pl_module=pl_module)
+ dataloader = trainer.datamodule.test_dataloader()
+ self._log_reconstruction(
+ stage="test", trainer=trainer, pl_module=pl_module, dataloader=dataloader
+ )
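
LogReconstuctedImages now logs a wandb.Table with one (input, reconstruction) image pair per row instead of a bare list of images. A standalone sketch of the same table construction with fake tensors; wandb.init is assumed to have been called before logging:

import torch
import wandb

imgs = torch.rand(8, 1, 576, 640)
reconstructions = torch.rand_like(imgs)  # stand-in for pl_module(imgs)[0]
data = [
    [wandb.Image(img), wandb.Image(rec)]
    for img, rec in zip(imgs, reconstructions)
]
table = wandb.Table(data=data, columns=["input", "reconstruction"])
# wandb.log({"Reconstructions/example/val": table})
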
diff --git a/training/conf/callbacks/wandb_image_reconstructions.yaml b/training/conf/callbacks/wandb_image_reconstructions.yaml
index e69de29..6cc4ada 100644
--- a/training/conf/callbacks/wandb_image_reconstructions.yaml
+++ b/training/conf/callbacks/wandb_image_reconstructions.yaml
@@ -0,0 +1,3 @@
+log_image_reconstruction:
+ _target_: callbacks.wandb_callbacks.LogReconstuctedImages
+ num_samples: 8
diff --git a/training/conf/callbacks/wandb_vae.yaml b/training/conf/callbacks/wandb_vae.yaml
new file mode 100644
index 0000000..609a8e8
--- /dev/null
+++ b/training/conf/callbacks/wandb_vae.yaml
@@ -0,0 +1,6 @@
+defaults:
+ - default
+ - wandb_watch
+ - wandb_code
+ - wandb_checkpoints
+ - wandb_image_reconstructions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 782bcbb..6b74502 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,3 +1,5 @@
+# @package _global_
+
defaults:
- callbacks: wandb_ocr
- criterion: label_smoothing
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
new file mode 100644
index 0000000..13e5f34
--- /dev/null
+++ b/training/conf/experiment/vqvae.yaml
@@ -0,0 +1,20 @@
+# @package _global_
+
+defaults:
+ - override /network: vqvae
+ - override /criterion: mse
+ - override /model: lit_vqvae
+ - override /callbacks: wandb_vae
+
+trainer:
+ max_epochs: 64
+
+datamodule:
+ batch_size: 32
+
+lr_scheduler:
+ epochs: 64
+ steps_per_epoch: 624
+
+optimizer:
+ lr: 1.0e-2
diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml
deleted file mode 100644
index 0858c3d..0000000
--- a/training/conf/experiment/vqvae_experiment.yaml
+++ /dev/null
@@ -1,13 +0,0 @@
-defaults:
- - override /network: vqvae
- - override /criterion: mse
- - override /optimizer: madgrad
- - override /lr_scheduler: one_cycle
- - override /model: lit_vqvae
- - override /dataset: iam_extended_paragraphs
- - override /trainer: default
- - override /callbacks:
- - wandb
-
-load_checkpoint: null
-logging: INFO
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index b337fe6..8837573 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,2 +1,4 @@
_target_: text_recognizer.models.vqvae.VQVAELitModel
-mapping: sentence_piece
+interval: step
+monitor: val/loss
+latent_loss_weight: 0.25
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f76e892..d3a3b0f 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -4,7 +4,7 @@ defaults:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
-hidden_dim: 96
+hidden_dim: 128
dropout_rate: 0.2
num_classes: 1006
pad_index: 1002
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index eb80f64..c326c04 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -2,12 +2,12 @@ defaults:
- rotary_emb: null
_target_: text_recognizer.networks.transformer.Decoder
-dim: 96
+dim: 128
depth: 2
num_heads: 8
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
- dim_head: 16
+ dim_head: 64
dropout_rate: 0.2
norm_fn: torch.nn.LayerNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 22eebf8..5a5c066 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -1,13 +1,8 @@
-type: VQVAE
-args:
- in_channels: 1
- channels: [64, 96]
- kernel_sizes: [4, 4]
- strides: [2, 2]
- num_residual_layers: 2
- embedding_dim: 64
- num_embeddings: 256
- upsampling: null
- beta: 0.25
- activation: leaky_relu
- dropout_rate: 0.2
+_target_: text_recognizer.networks.vqvae.VQVAE
+in_channels: 1
+res_channels: 32
+num_residual_layers: 2
+embedding_dim: 64
+num_embeddings: 512
+decay: 0.99
+activation: mish