Diffstat (limited to 'notebooks')
 notebooks/00-testing-stuff-out.ipynb      | 555
 notebooks/03-look-at-iam-paragraphs.ipynb |  63
 2 files changed, 201 insertions(+), 417 deletions(-)
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
index e6cf099..7c7b3a6 100644
--- a/notebooks/00-testing-stuff-out.ipynb
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -25,7 +25,7 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -34,7 +34,7 @@
},
{
"cell_type": "code",
- "execution_count": 17,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -43,7 +43,7 @@
},
{
"cell_type": "code",
- "execution_count": 18,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -52,110 +52,16 @@
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "seed: 4711\n",
- "network:\n",
- " desc: Configuration of the PyTorch neural network.\n",
- " type: VQVAE\n",
- " args:\n",
- " in_channels: 1\n",
- " channels:\n",
- " - 32\n",
- " - 64\n",
- " - 96\n",
- " - 96\n",
- " - 128\n",
- " kernel_sizes:\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " - 4\n",
- " strides:\n",
- " - 2\n",
- " - 2\n",
- " - 2\n",
- " - 2\n",
- " - 2\n",
- " num_residual_layers: 2\n",
- " embedding_dim: 128\n",
- " num_embeddings: 1024\n",
- " upsampling: null\n",
- " beta: 0.25\n",
- " activation: leaky_relu\n",
- " dropout_rate: 0.1\n",
- "model:\n",
- " desc: Configuration of the PyTorch Lightning model.\n",
- " type: LitVQVAEModel\n",
- " args:\n",
- " optimizer:\n",
- " type: MADGRAD\n",
- " args:\n",
- " lr: 0.001\n",
- " momentum: 0.9\n",
- " weight_decay: 0\n",
- " eps: 1.0e-06\n",
- " lr_scheduler:\n",
- " type: OneCycleLR\n",
- " args:\n",
- " interval: step\n",
- " max_lr: 0.001\n",
- " three_phase: true\n",
- " epochs: 1024\n",
- " steps_per_epoch: 317\n",
- " criterion:\n",
- " type: MSELoss\n",
- " args:\n",
- " reduction: mean\n",
- " monitor: val_loss\n",
- " mapping: sentence_piece\n",
- "data:\n",
- " desc: Configuration of the training/test data.\n",
- " type: IAMExtendedParagraphs\n",
- " args:\n",
- " batch_size: 64\n",
- " num_workers: 12\n",
- " train_fraction: 0.8\n",
- " augment: true\n",
- "callbacks:\n",
- "- type: ModelCheckpoint\n",
- " args:\n",
- " monitor: val_loss\n",
- " mode: min\n",
- " save_last: true\n",
- "- type: LearningRateMonitor\n",
- " args:\n",
- " logging_interval: step\n",
- "trainer:\n",
- " desc: Configuration of the PyTorch Lightning Trainer.\n",
- " args:\n",
- " stochastic_weight_avg: false\n",
- " auto_scale_batch_size: binsearch\n",
- " gradient_clip_val: 0\n",
- " fast_dev_run: false\n",
- " gpus: 1\n",
- " precision: 16\n",
- " max_epochs: 1024\n",
- " terminate_on_nan: true\n",
- " weights_summary: full\n",
- "load_checkpoint: null\n",
- "\n"
- ]
- }
- ],
+ "outputs": [],
"source": [
"print(OmegaConf.to_yaml(conf))"
]
},
{
"cell_type": "code",
- "execution_count": 20,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -164,7 +70,7 @@
},
{
"cell_type": "code",
- "execution_count": 21,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -173,167 +79,16 @@
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "VQVAE(\n",
- " (encoder): Encoder(\n",
- " (encoder): Sequential(\n",
- " (0): Sequential(\n",
- " (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (1): Dropout(p=0.1, inplace=False)\n",
- " (2): Sequential(\n",
- " (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): Sequential(\n",
- " (0): Conv2d(64, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (5): Dropout(p=0.1, inplace=False)\n",
- " (6): Sequential(\n",
- " (0): Conv2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (7): Dropout(p=0.1, inplace=False)\n",
- " (8): Sequential(\n",
- " (0): Conv2d(96, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (9): Dropout(p=0.1, inplace=False)\n",
- " (10): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (11): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (12): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
- " )\n",
- " (vector_quantizer): VectorQuantizer(\n",
- " (embedding): Embedding(1024, 128)\n",
- " )\n",
- " )\n",
- " (decoder): Decoder(\n",
- " (res_block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (2): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (upsampling_block): Sequential(\n",
- " (0): Sequential(\n",
- " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (1): Dropout(p=0.1, inplace=False)\n",
- " (2): Sequential(\n",
- " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): Sequential(\n",
- " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (5): Dropout(p=0.1, inplace=False)\n",
- " (6): Sequential(\n",
- " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (7): Dropout(p=0.1, inplace=False)\n",
- " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (9): Tanh()\n",
- " )\n",
- " (decoder): Sequential(\n",
- " (0): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1))\n",
- " (1): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " (2): _ResidualBlock(\n",
- " (block): Sequential(\n",
- " (0): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
- " (1): ReLU(inplace=True)\n",
- " (2): Conv2d(128, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " )\n",
- " )\n",
- " )\n",
- " (1): Sequential(\n",
- " (0): Sequential(\n",
- " (0): ConvTranspose2d(128, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (1): Dropout(p=0.1, inplace=False)\n",
- " (2): Sequential(\n",
- " (0): ConvTranspose2d(96, 96, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (3): Dropout(p=0.1, inplace=False)\n",
- " (4): Sequential(\n",
- " (0): ConvTranspose2d(96, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (5): Dropout(p=0.1, inplace=False)\n",
- " (6): Sequential(\n",
- " (0): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
- " )\n",
- " (7): Dropout(p=0.1, inplace=False)\n",
- " (8): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))\n",
- " (9): Tanh()\n",
- " )\n",
- " )\n",
- " )\n",
- ")"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"vae"
]
},
{
"cell_type": "code",
- "execution_count": 5,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
@@ -342,275 +97,259 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "proj = nn.Conv2d(1, 32, kernel_size=16, stride=16)"
+ "vae.encoder(datum)[0].shape"
]
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "x = proj(datum)"
+ "vae(datum)[0].shape"
]
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 36, 40])"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "x.shape"
+ "from text_recognizer.networks.backbones.efficientnet import EfficientNet"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "xx = x.flatten(2)"
+ "en = EfficientNet()"
]
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": 3,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 1440])"
- ]
- },
- "execution_count": 10,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "xx.shape"
+ "datum = torch.randn([2, 1, 576, 640])"
]
},
{
"cell_type": "code",
- "execution_count": 11,
+ "execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
- "xxx = xx.transpose(1,2)"
+ "trg = torch.randint(0, 1000, [2, 682])"
]
},
{
"cell_type": "code",
- "execution_count": 12,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1440, 32])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "xxx.shape"
+ "trg.shape"
]
},
{
"cell_type": "code",
- "execution_count": 13,
+ "execution_count": null,
"metadata": {},
"outputs": [],
"source": [
- "from einops import rearrange"
+ "datum = torch.randn([2, 1, 224, 224])"
]
},
{
"cell_type": "code",
- "execution_count": 14,
- "metadata": {},
+ "execution_count": null,
+ "metadata": {
+ "scrolled": false
+ },
"outputs": [],
"source": [
- "xxxx = rearrange(x, \"b c h w -> b ( h w ) c\")"
+ "en(datum).shape"
]
},
{
"cell_type": "code",
- "execution_count": 15,
+ "execution_count": 5,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1440, 32])"
- ]
- },
- "execution_count": 15,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "xxxx.shape"
+ "path = \"../training/configs/cnn_transformer.yaml\""
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
- " B, N, C = x.shape\n",
- " H, W = size\n",
- " assert N == 1 + H * W\n",
- "\n",
- " # Extract CLS token and image tokens.\n",
- " cls_token, img_tokens = x[:, :1], x[:, 1:] # Shape: [B, 1, C], [B, H*W, C].\n",
- " \n",
- " # Depthwise convolution.\n",
- " feat = img_tokens.transpose(1, 2).view(B, C, H, W)"
+ "conf = OmegaConf.load(path)"
]
},
{
"cell_type": "code",
- "execution_count": 22,
+ "execution_count": 7,
"metadata": {},
"outputs": [
{
- "data": {
- "text/plain": [
- "torch.Size([2, 32, 36, 40])"
- ]
- },
- "execution_count": 22,
- "metadata": {},
- "output_type": "execute_result"
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "seed: 4711\n",
+ "network:\n",
+ " desc: Configuration of the PyTorch neural network.\n",
+ " type: CNNTransformer\n",
+ " args:\n",
+ " encoder:\n",
+ " type: EfficientNet\n",
+ " args: null\n",
+ " num_decoder_layers: 4\n",
+ " hidden_dim: 256\n",
+ " num_heads: 4\n",
+ " expansion_dim: 1024\n",
+ " dropout_rate: 0.1\n",
+ " transformer_activation: glu\n",
+ "model:\n",
+ " desc: Configuration of the PyTorch Lightning model.\n",
+ " type: LitTransformerModel\n",
+ " args:\n",
+ " optimizer:\n",
+ " type: MADGRAD\n",
+ " args:\n",
+ " lr: 0.001\n",
+ " momentum: 0.9\n",
+ " weight_decay: 0\n",
+ " eps: 1.0e-06\n",
+ " lr_scheduler:\n",
+ " type: OneCycleLR\n",
+ " args:\n",
+ " interval: step\n",
+ " max_lr: 0.001\n",
+ " three_phase: true\n",
+ " epochs: 512\n",
+ " steps_per_epoch: 1246\n",
+ " criterion:\n",
+ " type: CrossEntropyLoss\n",
+ " args:\n",
+ " weight: None\n",
+ " ignore_index: -100\n",
+ " reduction: mean\n",
+ " monitor: val_loss\n",
+ " mapping: sentence_piece\n",
+ "data:\n",
+ " desc: Configuration of the training/test data.\n",
+ " type: IAMExtendedParagraphs\n",
+ " args:\n",
+ " batch_size: 16\n",
+ " num_workers: 12\n",
+ " train_fraction: 0.8\n",
+ " augment: true\n",
+ "callbacks:\n",
+ "- type: ModelCheckpoint\n",
+ " args:\n",
+ " monitor: val_loss\n",
+ " mode: min\n",
+ " save_last: true\n",
+ "- type: StochasticWeightAveraging\n",
+ " args:\n",
+ " swa_epoch_start: 0.8\n",
+ " swa_lrs: 0.05\n",
+ " annealing_epochs: 10\n",
+ " annealing_strategy: cos\n",
+ " device: null\n",
+ "- type: LearningRateMonitor\n",
+ " args:\n",
+ " logging_interval: step\n",
+ "- type: EarlyStopping\n",
+ " args:\n",
+ " monitor: val_loss\n",
+ " mode: min\n",
+ " patience: 10\n",
+ "trainer:\n",
+ " desc: Configuration of the PyTorch Lightning Trainer.\n",
+ " args:\n",
+ " stochastic_weight_avg: true\n",
+ " auto_scale_batch_size: binsearch\n",
+ " gradient_clip_val: 0\n",
+ " fast_dev_run: false\n",
+ " gpus: 1\n",
+ " precision: 16\n",
+ " max_epochs: 512\n",
+ " terminate_on_nan: true\n",
+ " weights_summary: true\n",
+ "load_checkpoint: null\n",
+ "\n"
+ ]
}
],
"source": [
- "xxx.transpose(1, 2).view(2, 32, 36, 40).shape"
+ "print(OmegaConf.to_yaml(conf))"
]
},
{
"cell_type": "code",
- "execution_count": 18,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "72.0"
- ]
- },
- "execution_count": 18,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
"source": [
- "576 / 8"
+ "from text_recognizer.networks.cnn_transformer import CNNTransformer"
]
},
{
"cell_type": "code",
- "execution_count": 19,
+ "execution_count": 9,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "80.0"
- ]
- },
- "execution_count": 19,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "640 / 8"
+ "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)"
]
},
{
"cell_type": "code",
- "execution_count": 26,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 1, 576, 640])"
- ]
- },
- "execution_count": 26,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "datum.shape"
+ "t.encode(datum).shape"
]
},
{
"cell_type": "code",
- "execution_count": 27,
+ "execution_count": null,
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([2, 128, 18, 20])"
- ]
- },
- "execution_count": 27,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
- "vae.encoder(datum)[0].shape"
+ "trg.shape"
]
},
{
"cell_type": "code",
- "execution_count": 87,
+ "execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
- "torch.Size([2, 1, 576, 640])"
+ "torch.Size([2, 682, 1004])"
]
},
- "execution_count": 87,
+ "execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "vae(datum)[0].shape"
+ "t(datum, trg).shape"
]
},
{
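[Note: the updated cells above amount to a single shape-check script. The following is a sketch, not part of the patch: the import paths, the CNNTransformer constructor signature, and the expected output shape (2, 682, 1004) are taken verbatim from the hunks above; everything else (variable names, the assert) is illustrative.]

    # Sketch of the CNNTransformer smoke test in notebooks/00-testing-stuff-out.ipynb.
    import torch
    from omegaconf import OmegaConf

    from text_recognizer.networks.backbones.efficientnet import EfficientNet
    from text_recognizer.networks.cnn_transformer import CNNTransformer

    # Load and inspect the network/model/data config, as in cells 5-7.
    conf = OmegaConf.load("../training/configs/cnn_transformer.yaml")
    print(OmegaConf.to_yaml(conf))

    # Dummy batch matching the IAM paragraph input dims (1, 576, 640)
    # and output dims (682, 1) reported by the data module.
    datum = torch.randn([2, 1, 576, 640])
    trg = torch.randint(0, 1000, [2, 682])

    # Backbone alone; its output shape is not recorded in the diff
    # (that cell was left unexecuted).
    en = EfficientNet()
    feats = en(datum)

    # Full encoder-decoder; kwargs come from the YAML's network.args.
    t = CNNTransformer(
        input_shape=(1, 576, 640),
        output_shape=(682, 1),
        **conf.network.args,
    )
    logits = t(datum, trg)
    # Cell 10's recorded output: (batch, target length, vocab size).
    assert logits.shape == torch.Size([2, 682, 1004])
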
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index eaf5397..add0b80 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -31,7 +31,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 2,
"id": "726ac25b",
"metadata": {},
"outputs": [],
@@ -56,8 +56,8 @@
"name": "stderr",
"output_type": "stream",
"text": [
- "2021-04-16 23:01:52.352 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n",
- "2021-04-16 23:02:08.521 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n"
+ "2021-04-25 23:17:44.177 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n",
+ "2021-04-25 23:18:00.750 | INFO | text_recognizer.data.iam_synthetic_paragraphs:setup:79 - IAM Synthetic dataset steup for stage None\n"
]
},
{
@@ -68,9 +68,9 @@
"Num classes: 84\n",
"Dims: (1, 576, 640)\n",
"Output dims: (682, 1)\n",
- "Train/val/test sizes: 19912, 262, 231\n",
- "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0043), tensor(0.0333), tensor(0.8588))\n",
- "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(78))\n",
+ "Train/val/test sizes: 19948, 262, 231\n",
+ "Train Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0109), tensor(0.0499), tensor(0.8314))\n",
+ "Train Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n",
"Test Batch x stats: (torch.Size([1, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0372), tensor(0.0767), tensor(0.8118))\n",
"Test Batch y stats: (torch.Size([1, 682]), torch.int64, tensor(1), tensor(83))\n",
"\n"
@@ -86,10 +86,34 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 4,
"id": "42501428",
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2021-04-25 23:18:14.449 | INFO | text_recognizer.data.iam_paragraphs:setup:107 - Loading IAM paragraph regions and lines for None...\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "IAM Paragraphs Dataset\n",
+ "Num classes: 84\n",
+ "Input dims: (1, 576, 640)\n",
+ "Output dims: (682, 1)\n",
+ "Train/val/test sizes: 1046, 262, 231\n",
+ "Train Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0393), tensor(0.0924), tensor(1.))\n",
+ "Train Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n",
+ "Test Batch x stats: (torch.Size([16, 1, 576, 640]), torch.float32, tensor(0.), tensor(0.0312), tensor(0.0817), tensor(0.9294))\n",
+ "Test Batch y stats: (torch.Size([16, 682]), torch.int64, tensor(1), tensor(83))\n",
+ "\n"
+ ]
+ }
+ ],
"source": [
"dataset = IAMParagraphs()\n",
"dataset.prepare_data()\n",
@@ -99,7 +123,7 @@
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 5,
"id": "0cf22683",
"metadata": {},
"outputs": [],
@@ -109,6 +133,27 @@
},
{
"cell_type": "code",
+ "execution_count": 6,
+ "id": "af7747a8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([682])"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "y.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
"execution_count": 7,
"id": "e7778ae2",
"metadata": {