From 9426cc794d8c28a65bbbf5ae5466a0a343078558 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 25 Apr 2021 23:32:50 +0200 Subject: Efficient net and non working transformer model. --- notebooks/00-testing-stuff-out.ipynb | 555 ++++++++---------------------- notebooks/03-look-at-iam-paragraphs.ipynb | 63 +++- 2 files changed, 201 insertions(+), 417 deletions(-) (limited to 'notebooks') diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb index e6cf099..7c7b3a6 100644 --- a/notebooks/00-testing-stuff-out.ipynb +++ b/notebooks/00-testing-stuff-out.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -25,7 +25,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -34,7 +34,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -43,7 +43,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -52,110 +52,16 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "seed: 4711\n", - "network:\n", - " desc: Configuration of the PyTorch neural network.\n", - " type: VQVAE\n", - " args:\n", - " in_channels: 1\n", - " channels:\n", - " - 32\n", - " - 64\n", - " - 96\n", - " - 96\n", - " - 128\n", - " kernel_sizes:\n", - " - 4\n", - " - 4\n", - " - 4\n", - " - 4\n", - " - 4\n", - " strides:\n", - " - 2\n", - " - 2\n", - " - 2\n", - " - 2\n", - " - 2\n", - " num_residual_layers: 2\n", - " embedding_dim: 128\n", - " num_embeddings: 1024\n", - " upsampling: null\n", - " beta: 0.25\n", - " activation: leaky_relu\n", - " dropout_rate: 0.1\n", - "model:\n", - " desc: Configuration of the PyTorch Lightning model.\n", - " type: LitVQVAEModel\n", - " args:\n", - " optimizer:\n", - " type: MADGRAD\n", - " args:\n", - " lr: 0.001\n", - " momentum: 0.9\n", - " weight_decay: 0\n", - " eps: 1.0e-06\n", - " lr_scheduler:\n", - " type: OneCycleLR\n", - " args:\n", - " interval: step\n", - " max_lr: 0.001\n", - " three_phase: true\n", - " epochs: 1024\n", - " steps_per_epoch: 317\n", - " criterion:\n", - " type: MSELoss\n", - " args:\n", - " reduction: mean\n", - " monitor: val_loss\n", - " mapping: sentence_piece\n", - "data:\n", - " desc: Configuration of the training/test data.\n", - " type: IAMExtendedParagraphs\n", - " args:\n", - " batch_size: 64\n", - " num_workers: 12\n", - " train_fraction: 0.8\n", - " augment: true\n", - "callbacks:\n", - "- type: ModelCheckpoint\n", - " args:\n", - " monitor: val_loss\n", - " mode: min\n", - " save_last: true\n", - "- type: LearningRateMonitor\n", - " args:\n", - " logging_interval: step\n", - "trainer:\n", - " desc: Configuration of the PyTorch Lightning Trainer.\n", - " args:\n", - " stochastic_weight_avg: false\n", - " auto_scale_batch_size: binsearch\n", - " gradient_clip_val: 0\n", - " fast_dev_run: false\n", - " gpus: 1\n", - " precision: 16\n", - " max_epochs: 1024\n", - " terminate_on_nan: true\n", - " weights_summary: full\n", - "load_checkpoint: null\n", - "\n" - ] - } - ], + "outputs": [], "source": [ "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -164,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -173,167 +79,16 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "VQVAE(\n", - " (encoder): Encoder(\n", - " (encoder): Sequential(\n", - " (0): Sequential(\n", - " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (1): Dropout(p=0.1, inplace=False)\n", - " (2): Sequential(\n", - " (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): Sequential(\n", - " (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (5): Dropout(p=0.1, inplace=False)\n", - " (6): Sequential(\n", - " (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (7): Dropout(p=0.1, inplace=False)\n", - " (8): Sequential(\n", - " (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (9): Dropout(p=0.1, inplace=False)\n", - " (10): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (11): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (12): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " )\n", - " (vector_quantizer): VectorQuantizer(\n", - " (embedding): Embedding(1024, 128)\n", - " )\n", - " )\n", - " (decoder): Decoder(\n", - " (res_block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (2): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " (upsampling_block): Sequential(\n", - " (0): Sequential(\n", - " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (1): Dropout(p=0.1, inplace=False)\n", - " (2): Sequential(\n", - " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): Sequential(\n", - " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (5): Dropout(p=0.1, inplace=False)\n", - " (6): Sequential(\n", - " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (7): Dropout(p=0.1, inplace=False)\n", - " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (9): Tanh()\n", - " )\n", - " (decoder): Sequential(\n", - " (0): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n", - " (1): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " (2): _ResidualBlock(\n", - " (block): Sequential(\n", - " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n", - " (1): ReLU(inplace=True)\n", - " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " )\n", - " )\n", - " )\n", - " (1): Sequential(\n", - " (0): Sequential(\n", - " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (1): Dropout(p=0.1, inplace=False)\n", - " (2): Sequential(\n", - " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (3): Dropout(p=0.1, inplace=False)\n", - " (4): Sequential(\n", - " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (5): Dropout(p=0.1, inplace=False)\n", - " (6): Sequential(\n", - " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n", - " )\n", - " (7): Dropout(p=0.1, inplace=False)\n", - " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n", - " (9): Tanh()\n", - " )\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "vae" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -342,275 +97,259 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)" + "vae.encoder(datum)[0].shape" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "x = proj(datum)" + "vae(datum)[0].shape" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 32, 36, 40])" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "x.shape" + "from text_recognizer.networks.backbones.efficientnet import EfficientNet" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "xx = x.flatten(2)" + "en = EfficientNet()" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 32, 1440])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "xx.shape" + "datum = torch.randn([2, 1, 576, 640])" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "metadata": {}, "outputs": [], "source": [ - "xxx = xx.transpose(1,2)" + "trg = torch.randint(0, 1000, [2, 682])" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1440, 32])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "xxx.shape" + "trg.shape" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ - "from einops import rearrange" + "datum = torch.randn([2, 1, 224, 224])" ] }, { "cell_type": "code", - "execution_count": 14, - "metadata": {}, + "execution_count": null, + "metadata": { + "scrolled": false + }, "outputs": [], "source": [ - "xxxx = rearrange(x, \"b c h w -> b ( h w ) c\")" + "en(datum).shape" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1440, 32])" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "xxxx.shape" + "path = \"../training/configs/cnn_transformer.yaml\"" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - " B, N, C = x.shape\n", - " H, W = size\n", - " assert N == 1 + H * W\n", - "\n", - " # Extract CLS token and image tokens.\n", - " cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].\n", - " \n", - " # Depthwise convolution.\n", - " feat = img_tokens.transpose(1, 2).view(B, C, H, W)" + "conf = OmegaConf.load(path)" ] }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 7, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "torch.Size([2, 32, 36, 40])" - ] - }, - "execution_count": 22, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 4711\n", + "network:\n", + " desc: Configuration of the PyTorch neural network.\n", + " type: CNNTransformer\n", + " args:\n", + " encoder:\n", + " type: EfficientNet\n", + " args: null\n", + " num_decoder_layers: 4\n", + " hidden_dim: 256\n", + " num_heads: 4\n", + " expansion_dim: 1024\n", + " dropout_rate: 0.1\n", + " transformer_activation: glu\n", + "model:\n", + " desc: Configuration of the PyTorch Lightning model.\n", + " type: LitTransformerModel\n", + " args:\n", + " optimizer:\n", + " type: MADGRAD\n", + " args:\n", + " lr: 0.001\n", + " momentum: 0.9\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + " lr_scheduler:\n", + " type: OneCycleLR\n", + " args:\n", + " interval: step\n", + " max_lr: 0.001\n", + " three_phase: true\n", + " epochs: 512\n", + " steps_per_epoch: 1246\n", + " criterion:\n", + " type: CrossEntropyLoss\n", + " args:\n", + " weight: None\n", + " ignore_index: -100\n", + " reduction: mean\n", + " monitor: val_loss\n", + " mapping: sentence_piece\n", + "data:\n", + " desc: Configuration of the training/test data.\n", + " type: IAMExtendedParagraphs\n", + " args:\n", + " batch_size: 16\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " augment: true\n", + "callbacks:\n", + "- type: ModelCheckpoint\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " save_last: true\n", + "- type: StochasticWeightAveraging\n", + " args:\n", + " swa_epoch_start: 0.8\n", + " swa_lrs: 0.05\n", + " annealing_epochs: 10\n", + " annealing_strategy: cos\n", + " device: null\n", + "- type: LearningRateMonitor\n", + " args:\n", + " logging_interval: step\n", + "- type: EarlyStopping\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " patience: 10\n", + "trainer:\n", + " desc: Configuration of the PyTorch Lightning Trainer.\n", + " args:\n", + " stochastic_weight_avg: true\n", + " auto_scale_batch_size: binsearch\n", + " gradient_clip_val: 0\n", + " fast_dev_run: false\n", + " gpus: 1\n", + " precision: 16\n", + " max_epochs: 512\n", + " terminate_on_nan: true\n", + " weights_summary: true\n", + "load_checkpoint: null\n", + "\n" + ] } ], "source": [ - "xxx.transpose(1, 2).view(2, 32, 36, 40).shape" + "print(OmegaConf.to_yaml(conf))" ] }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/plain": [ - "72.0" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": 8, + "metadata": {}, + "outputs": [], "source": [ - "576 / 8" + "from text_recognizer.networks.cnn_transformer import CNNTransformer" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 9, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "80.0" - ] - }, - "execution_count": 19, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "640 / 8" + "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" ] }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 1, 576, 640])" - ] - }, - "execution_count": 26, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "datum.shape" + "t.encode(datum).shape" ] }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 128, 18, 20])" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "vae.encoder(datum)[0].shape" + "trg.shape" ] }, { "cell_type": "code", - "execution_count": 87, + "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "torch.Size([2, 1, 576, 640])" + "torch.Size([2, 682, 1004])" ] }, - "execution_count": 87, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "vae(datum)[0].shape" + "t(datum, trg).shape" ] }, { diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb index eaf5397..add0b80 100644 --- a/notebooks/03-look-at-iam-paragraphs.ipynb +++ b/notebooks/03-look-at-iam-paragraphs.ipynb @@ -31,7 +31,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 2, "id": "726ac25b", "metadata": {}, "outputs": [], @@ -56,8 +56,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "2021-04-16 23:01:52.352 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n", - "2021-04-16 23:02:08.521 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n" + "2021-04-25 23:17:44.177 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n", + "2021-04-25 23:18:00.750 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n" ] }, { @@ -68,9 +68,9 @@ "Num classes: 84\n", "Dims: (1, 576, 640)\n", "Output dims: (682, 1)\n", - "Train/val/test sizes: 19912, 262, 231\n", - "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0043), tensor(0.0333), tensor(0.8588))\n", - "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(78))\n", + "Train/val/test sizes: 19948, 262, 231\n", + "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0109), tensor(0.0499), tensor(0.8314))\n", + "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n", "Test Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0372), tensor(0.0767), tensor(0.8118))\n", "Test Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n", "\n" @@ -86,10 +86,34 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "42501428", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2021-04-25 23:18:14.449 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "IAM Paragraphs Dataset\n", + "Num classes: 84\n", + "Input dims: (1, 576, 640)\n", + "Output dims: (682, 1)\n", + "Train/val/test sizes: 1046, 262, 231\n", + "Train Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0393), tensor(0.0924), tensor(1.))\n", + "Train Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n", + "Test Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0312), tensor(0.0817), tensor(0.9294))\n", + "Test Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n", + "\n" + ] + } + ], "source": [ "dataset = IAMParagraphs()\n", "dataset.prepare_data()\n", @@ -99,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "0cf22683", "metadata": {}, "outputs": [], @@ -107,6 +131,27 @@ "x, y = dataset.data_train[1]" ] }, + { + "cell_type": "code", + "execution_count": 6, + "id": "af7747a8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([682])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "y.shape" + ] + }, { "cell_type": "code", "execution_count": 7, -- cgit v1.2.3-70-g09d2