diff options
Diffstat (limited to 'notebooks/00-testing-stuff-out.ipynb')
-rw-r--r-- | notebooks/00-testing-stuff-out.ipynb | 1614 |
1 files changed, 246 insertions, 1368 deletions
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index 4c93501..d4840ef 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -34,16 +34,16 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ - "path = \"../training/experiments/image_transformer.yaml\"" + "path = \"../training/configs/vqvae.yaml\"" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 74, "metadata": {}, "outputs": [], "source": [ @@ -52,42 +52,57 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 75, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ + "seed: 4711\n", "network:\n", + " desc: Configuration of the PyTorch neural network.\n", " type: ImageTransformer\n", " args:\n", - " input_shape: None\n", - " output_shape: None\n", - " encoder:\n", - " type: None\n", - " args: None\n", - " mapping: sentence_piece\n", - " num_decoder_layers: 4\n", - " hidden_dim: 256\n", - " num_heads: 4\n", - " expansion_dim: 1024\n", - " dropout_rate: 0.1\n", - " transformer_activation: glu\n", + " in_channels: 1\n", + " channels:\n", + " - 128\n", + " - 64\n", + " - 32\n", + " kernel_sizes:\n", + " - 4\n", + " - 4\n", + " - 4\n", + " strides:\n", + " - 2\n", + " - 2\n", + " - 2\n", + " num_residual_layers: 4\n", + " embedding_dim: 128\n", + " num_embeddings: 1024\n", + " upsampling: null\n", + " beta: 6.6\n", + " activation: leaky_relu\n", + " dropout_rate: 0.25\n", "model:\n", + " desc: Configuration of the PyTorch Lightning model.\n", " type: LitTransformerModel\n", " args:\n", " optimizer:\n", " type: MADGRAD\n", " args:\n", - " lr: 0.01\n", + " lr: 0.001\n", " momentum: 0.9\n", " weight_decay: 0\n", " eps: 1.0e-06\n", " lr_scheduler:\n", - " type: CosineAnnealingLR\n", + " type: OneCycle\n", " args:\n", - " T_max: 512\n", + " interval: step\n", + " max_lr: 0.001\n", + " three_phase: true\n", + " epochs: 512\n", + " steps_per_epoch: 1246\n", " criterion:\n", " type: CrossEntropyLoss\n", " args:\n", @@ -97,6 +112,7 @@ " monitor: val_loss\n", " mapping: sentence_piece\n", "data:\n", + " desc: Configuration of the training/test data.\n", " type: IAMExtendedParagraphs\n", " args:\n", " batch_size: 16\n", @@ -108,22 +124,35 @@ " args:\n", " monitor: val_loss\n", " mode: min\n", + " save_last: true\n", + "- type: StochasticWeightAveraging\n", + " args:\n", + " swa_epoch_start: 0.8\n", + " swa_lrs: 0.05\n", + " annealing_epochs: 10\n", + " annealing_strategy: cos\n", + " device: null\n", + "- type: LearningRateMonitor\n", + " args:\n", + " logging_interval: step\n", "- type: EarlyStopping\n", " args:\n", " monitor: val_loss\n", " mode: min\n", " patience: 10\n", "trainer:\n", + " desc: Configuration of the PyTorch Lightning Trainer.\n", " args:\n", " stochastic_weight_avg: true\n", - " auto_scale_batch_size: power\n", + " auto_scale_batch_size: binsearch\n", " gradient_clip_val: 0\n", " fast_dev_run: false\n", " gpus: 1\n", " precision: 16\n", - " max_epocs: 512\n", + " max_epochs: 512\n", " terminate_on_nan: true\n", " weights_summary: true\n", + "load_checkpoint: null\n", "\n" ] } @@ -134,1438 +163,287 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "' tes\".t '" - ] - }, - "execution_count": 1, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "\"\"\" tes\".t \"\"\"" - ] - }, - { - "cell_type": "code", - "execution_count": 2, + "execution_count": 76, "metadata": {}, "outputs": [], "source": [ - "from text_recognizer.networks import CNN, TDS2d" + "from text_recognizer.networks import VQVAE" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 78, "metadata": {}, "outputs": [], "source": [ - "tds2d = TDS2d(**{\n", - " \"depth\" : 4,\n", - " \"tds_groups\" : [\n", - " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", - " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", - " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", - " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n", - " ],\n", - " \"kernel_size\" : [5, 7],\n", - " \"dropout_rate\" : 0.1\n", - " }, input_dim=32, output_dim=128)" + "vae = VQVAE(**conf.network.args)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 79, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "TDS2d(\n", - " (tds): Sequential(\n", - " (0): Conv2d(1, 16, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (4): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=16, out_features=16, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=16, out_features=16, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", + "VQVAE(\n", + " (encoder): Encoder(\n", + " (encoder): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (1): Dropout(p=0.25, inplace=False)\n", + " (2): Sequential(\n", + " (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " (4): Sequential(\n", + " (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (5): Dropout(p=0.25, inplace=False)\n", + " (6): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (7): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (8): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (9): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (10): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", " )\n", - " (5): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=16, out_features=16, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=16, out_features=16, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", - " )\n", - " (6): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=16, out_features=16, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=16, out_features=16, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", - " )\n", - " (7): Conv2d(16, 128, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n", - " (8): ReLU(inplace=True)\n", - " (9): Dropout(p=0.1, inplace=False)\n", - " (10): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (11): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=128, out_features=128, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=128, out_features=128, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", - " )\n", - " (12): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=128, out_features=128, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=128, out_features=128, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", + " (vector_quantizer): VectorQuantizer(\n", + " (embedding): Embedding(1024, 128)\n", " )\n", - " (13): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=128, out_features=128, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=128, out_features=128, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", - " )\n", - " (14): Conv2d(128, 256, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n", - " (15): ReLU(inplace=True)\n", - " (16): Dropout(p=0.1, inplace=False)\n", - " (17): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (18): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=256, out_features=256, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=256, out_features=256, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", - " )\n", - " (19): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=256, out_features=256, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=256, out_features=256, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", - " )\n", - " (20): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=256, out_features=256, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=256, out_features=256, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", - " )\n", - " (21): Conv2d(256, 512, kernel_size=[5, 7], stride=[2, 1], padding=(2, 3))\n", - " (22): ReLU(inplace=True)\n", - " (23): Dropout(p=0.1, inplace=False)\n", - " (24): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (25): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=512, out_features=512, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=512, out_features=512, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " (decoder): Decoder(\n", + " (res_block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (2): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (3): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (4): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", " )\n", " )\n", - " (26): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=512, out_features=512, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=512, out_features=512, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " )\n", + " (upsampling_block): Sequential(\n", + " (0): Sequential(\n", + " (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (1): Dropout(p=0.25, inplace=False)\n", + " (2): Sequential(\n", + " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " (4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (5): Tanh()\n", " )\n", - " (27): TDSBlock2d(\n", - " (conv): Sequential(\n", - " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (mlp): Sequential(\n", - " (0): Linear(in_features=512, out_features=512, bias=True)\n", - " (1): ReLU(inplace=True)\n", - " (2): Dropout(p=0.1, inplace=False)\n", - " (3): Linear(in_features=512, out_features=512, bias=True)\n", - " (4): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (instance_norm): ModuleList(\n", - " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", - " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (decoder): Sequential(\n", + " (0): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", + " (1): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (2): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (3): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " (4): _ResidualBlock(\n", + " (block): Sequential(\n", + " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", + " (1): ReLU(inplace=True)\n", + " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " (1): Sequential(\n", + " (0): Sequential(\n", + " (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (1): Dropout(p=0.25, inplace=False)\n", + " (2): Sequential(\n", + " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", + " )\n", + " (3): Dropout(p=0.25, inplace=False)\n", + " (4): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", + " (5): Tanh()\n", " )\n", " )\n", " )\n", - " (fc): Linear(in_features=1024, out_features=128, bias=True)\n", ")" ] }, - "execution_count": 4, + "execution_count": 79, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tds2d" + "vae" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 80, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "===============================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "===============================================================================================\n", - "├─Sequential: 1-1 [-1, 512, 2, 119] --\n", - "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n", - "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n", - "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n", - "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n", - "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n", - "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n", - "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n", - "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n", - "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n", - "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n", - "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n", - "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n", - "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n", - "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n", - "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n", - "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n", - "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n", - "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n", - "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n", - "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n", - "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n", - "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n", - "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n", - "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n", - "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n", - "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n", - "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n", - "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n", - "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n", - "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n", - "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n", - "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n", - "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n", - "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n", - "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n", - "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n", - "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n", - "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n", - "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n", - "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n", - "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n", - "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n", - "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n", - "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n", - "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n", - "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n", - "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n", - "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n", - "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n", - "├─Linear: 1-2 [-1, 119, 128] 131,200\n", - "===============================================================================================\n", - "Total params: 10,272,252\n", - "Trainable params: 10,272,252\n", - "Non-trainable params: 0\n", - "Total mult-adds (G): 5.00\n", - "===============================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 73.21\n", - "Params size (MB): 39.19\n", - "Estimated Total Size (MB): 112.50\n", - "===============================================================================================\n" - ] - }, - { "data": { "text/plain": [ - "===============================================================================================\n", - "Layer (type:depth-idx) Output Shape Param #\n", - "===============================================================================================\n", - "├─Sequential: 1-1 [-1, 512, 2, 119] --\n", - "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n", - "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n", - "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n", - "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n", - "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n", - "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n", - "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n", - "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n", - "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n", - "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n", - "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n", - "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n", - "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n", - "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n", - "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n", - "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n", - "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n", - "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n", - "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n", - "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n", - "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n", - "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n", - "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n", - "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n", - "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n", - "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n", - "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n", - "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n", - "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n", - "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n", - "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n", - "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n", - "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n", - "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n", - "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n", - "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n", - "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n", - "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n", - "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n", - "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n", - "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n", - "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n", - "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n", - "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n", - "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n", - "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n", - "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n", - "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n", - "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n", - "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n", - "├─Linear: 1-2 [-1, 119, 128] 131,200\n", - "===============================================================================================\n", - "Total params: 10,272,252\n", - "Trainable params: 10,272,252\n", - "Non-trainable params: 0\n", - "Total mult-adds (G): 5.00\n", - "===============================================================================================\n", - "Input size (MB): 0.10\n", - "Forward/backward pass size (MB): 73.21\n", - "Params size (MB): 39.19\n", - "Estimated Total Size (MB): 112.50\n", - "===============================================================================================" + "tensor([1.])" ] }, - "execution_count": 5, + "execution_count": 80, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)" + "torch.Tensor([1])" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 81, "metadata": {}, "outputs": [], "source": [ - "t = torch.randn(2,1, 28, 952)" + "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 82, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([2, 119, 128])" + "torch.Size([2, 1, 576, 640])" ] }, - "execution_count": 7, + "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "tds2d(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "cnn = CNN().cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "i = nn.Sequential(nn.Conv2d(1,1,1,1))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nn.Sequential(i,i)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cnn(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randn(2, 1, 28, 952)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x, l = vqvae(t)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "5 * 59 / 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "up = nn.Upsample([4, 59])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "up(tt).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tt.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class GEGLU(nn.Module):\n", - " def __init__(self, dim_in, dim_out):\n", - " super().__init__()\n", - " self.proj = nn.Linear(dim_in, dim_out * 2)\n", - "\n", - " def forward(self, x):\n", - " x, gate = self.proj(x).chunk(2, dim = -1)\n", - " return x * F.gelu(gate)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "e = GEGLU(256, 2048)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "e(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "emb = nn.Embedding(56, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with torch.no_grad():\n", - " e = emb(torch.Tensor([55]).long())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from einops import repeat" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ee = repeat(e, \"() n -> b n\", b=16)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "emb.device" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ee" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ee.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randn(16, 10, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "e.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import WideResidualNetwork" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wr = WideResidualNetwork(\n", - " in_channels= 1,\n", - " num_classes= 80,\n", - " in_planes=64,\n", - " depth=10,\n", - " num_layers=4,\n", - " width_factor=2,\n", - " num_stages=[64, 128, 256, 256],\n", - " dropout_rate= 0.1,\n", - " activation= \"SELU\",\n", - " use_decoder= False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torchsummary import summary" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - " backbone = nn.Sequential(\n", - " *list(wr.children())[:][:]\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "backbone" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a = torch.rand(1, 1, 28, 952)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b = wr(a)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from einops import rearrange" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b = rearrange(b, \"b c h w -> b w c h\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "c = nn.AdaptiveAvgPool2d((None, 1))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d = c(b)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d.squeeze(3).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torch import nn" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "32 + 64" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "3 * 112" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "col_embed = nn.Parameter(torch.rand(1000, 256 // 2))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "W, H = 196, 4" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - " torch.cat(\n", - " [\n", - " col_embed[:W].unsqueeze(0).repeat(H, 1, 1),\n", - " col_embed[:H].unsqueeze(1).repeat(1, W, 1),\n", - " ],\n", - " dim=-1,\n", - " ).unsqueeze(0).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "4 * 196" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target = torch.tensor([1,1,12,1,1,1,1,1,9,9,9,9,9,9])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "torch.nonzero(target == 9, as_tuple=False)[0].item()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target[:9]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "np.inf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.transformer.positional_encoding import PositionalEncoding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(15, 5))\n", - "pe = PositionalEncoding(20, 0)\n", - "y = pe.forward(torch.zeros(1, 100, 20))\n", - "plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())\n", - "plt.legend([\"dim %d\"%p for p in [4,5,6,7]])\n", - "None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.densenet import DenseNet,_DenseLayer,_DenseBlock" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "216 / 8" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - " backbone = nn.Sequential(\n", - " *list(dnet.children())[:][:-4]\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "backbone" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import WideResidualNetwork" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "w = WideResidualNetwork(\n", - " in_channels = 1,\n", - " in_planes = 32,\n", - " num_classes = 80,\n", - " depth = 10,\n", - " width_factor = 1,\n", - " dropout_rate = 0.0,\n", - " num_layers = 5,\n", - " activation = \"relu\",\n", - " use_decoder = False,)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(w, (1, 28, 952), device=\"cpu\", depth=2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sz= 5" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask = torch.triu(torch.ones(sz, sz), 1)\n", - "mask = mask.masked_fill(mask==1, float('-inf'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "h = torch.rand(1, 256, 10, 10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "h.flatten(2).permute(2, 0, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "h.flatten(2).permute(2, 0, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred = torch.Tensor([1,21,2,45,31, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n", - "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask = (target != 79)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred * mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target * mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.models.metrics import accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pad_indcies = torch.nonzero(target == 79, as_tuple=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target.shape[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t2 = torch.arange(10, target.shape[0] + 1, 10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t2" + "datum.shape" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 85, "metadata": {}, - "outputs": [], - "source": [ - "for start, stop in zip(t1, t2):\n", - " pred[start+1:stop] = 79" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "pad_indcies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred[pad_indcies:pad_indcies] = 79" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "accuracy(pred, target)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 128, 72, 80])" + ] + }, + "execution_count": 85, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "acc = (pred == target).sum().float() / target.shape[0]" + "vae.encoder(datum)[0].shape" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 87, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 1, 576, 640])" + ] + }, + "execution_count": 87, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "acc" + "vae(datum)[0].shape" ] }, { |