summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-03-20 18:09:06 +0100
commit7e8e54e84c63171e748bbf09516fd517e6821ace (patch)
tree996093f75a5d488dddf7ea1f159ed343a561ef89
parentb0719d84138b6bbe5f04a4982dfca673aea1a368 (diff)
Inital commit for refactoring to lightning
-rw-r--r--.gitattributes (renamed from src/.gitattributes)0
-rw-r--r--notebooks/00-testing-stuff-out.ipynb1469
-rw-r--r--notebooks/01-look-at-emnist.ipynb (renamed from src/notebooks/01-look-at-emnist.ipynb)0
-rw-r--r--notebooks/02a-sentence-generator.ipynb (renamed from src/notebooks/02a-sentence-generator.ipynb)0
-rw-r--r--notebooks/02b-emnist-lines-dataset.ipynb (renamed from src/notebooks/02b-emnist-lines-dataset.ipynb)0
-rw-r--r--notebooks/02c-image-patches.ipynb (renamed from src/notebooks/02c-image-patches.ipynb)0
-rw-r--r--notebooks/03a-line-prediction.ipynb (renamed from src/notebooks/03a-line-prediction.ipynb)0
-rw-r--r--notebooks/04a-look-at-iam-lines.ipynb (renamed from src/notebooks/04a-look-at-iam-lines.ipynb)0
-rw-r--r--notebooks/04b-look-at-iam-paragraphs-predictions.ipynb (renamed from src/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb)0
-rw-r--r--notebooks/04b-look-at-iam-paragraphs.ipynb (renamed from src/notebooks/04b-look-at-iam-paragraphs.ipynb)0
-rw-r--r--notebooks/05-sanity-check-multihead-attention.ipynb (renamed from src/notebooks/05-sanity-check-multihead-attention.ipynb)0
-rw-r--r--notebooks/05a-UNet.ipynb (renamed from src/notebooks/05a-UNet.ipynb)0
-rw-r--r--notebooks/05a-test-end-to-end-model.ipynb (renamed from src/notebooks/05a-test-end-to-end-model.ipynb)0
-rw-r--r--notebooks/06-try-transformer-model-predictions.ipynb (renamed from src/notebooks/06-try-transformer-model-predictions.ipynb)0
-rw-r--r--notebooks/07-look-at-lexicon.ipynb (renamed from src/notebooks/07-look-at-lexicon.ipynb)0
-rw-r--r--notebooks/07-try-gtn.ipynb (renamed from src/notebooks/07-try-gtn.ipynb)0
-rw-r--r--notebooks/Untitled.ipynb (renamed from src/notebooks/Untitled.ipynb)0
-rw-r--r--notebooks/g1.png (renamed from src/notebooks/g1.png)bin8590 -> 8590 bytes
-rw-r--r--notebooks/g2.png (renamed from src/notebooks/g2.png)bin5247 -> 5247 bytes
-rw-r--r--notebooks/intersect.png (renamed from src/notebooks/intersect.png)bin7953 -> 7953 bytes
-rw-r--r--notebooks/intersection.pdf (renamed from src/notebooks/intersection.pdf)bin10154 -> 10154 bytes
-rw-r--r--noxfile.py20
-rw-r--r--poetry.lock216
-rw-r--r--pyproject.toml2
-rw-r--r--src/notebooks/00-testing-stuff-out.ipynb1059
-rw-r--r--src/text_recognizer/tests/support/emnist/8.pngbin498 -> 0 bytes
-rw-r--r--src/text_recognizer/tests/support/emnist/U.pngbin524 -> 0 bytes
-rw-r--r--src/text_recognizer/tests/support/emnist/e.pngbin563 -> 0 bytes
-rw-r--r--tasks/build_transitions.py (renamed from src/tasks/build_transitions.py)0
-rwxr-xr-xtasks/create_emnist_lines_datasets.sh (renamed from src/tasks/create_emnist_lines_datasets.sh)0
-rwxr-xr-xtasks/create_iam_paragraphs.sh (renamed from src/tasks/create_iam_paragraphs.sh)0
-rwxr-xr-xtasks/download_emnist.sh (renamed from src/tasks/download_emnist.sh)0
-rwxr-xr-xtasks/download_iam.sh (renamed from src/tasks/download_iam.sh)0
-rw-r--r--tasks/make_wordpieces.py (renamed from src/tasks/make_wordpieces.py)0
-rwxr-xr-xtasks/prepare_experiments.sh (renamed from src/tasks/prepare_experiments.sh)0
-rwxr-xr-xtasks/test_functionality.sh (renamed from src/tasks/test_functionality.sh)0
-rwxr-xr-xtasks/train.sh (renamed from src/tasks/train.sh)0
-rw-r--r--text_recognizer/__init__.py (renamed from src/text_recognizer/__init__.py)0
-rw-r--r--text_recognizer/character_predictor.py (renamed from src/text_recognizer/character_predictor.py)0
-rw-r--r--text_recognizer/datasets/__init__.py (renamed from src/text_recognizer/datasets/__init__.py)0
-rw-r--r--text_recognizer/datasets/dataset.py (renamed from src/text_recognizer/datasets/dataset.py)0
-rw-r--r--text_recognizer/datasets/emnist_dataset.py (renamed from src/text_recognizer/datasets/emnist_dataset.py)0
-rw-r--r--text_recognizer/datasets/emnist_essentials.json (renamed from src/text_recognizer/datasets/emnist_essentials.json)0
-rw-r--r--text_recognizer/datasets/emnist_lines_dataset.py (renamed from src/text_recognizer/datasets/emnist_lines_dataset.py)0
-rw-r--r--text_recognizer/datasets/iam_dataset.py (renamed from src/text_recognizer/datasets/iam_dataset.py)1
-rw-r--r--text_recognizer/datasets/iam_lines_dataset.py (renamed from src/text_recognizer/datasets/iam_lines_dataset.py)0
-rw-r--r--text_recognizer/datasets/iam_paragraphs_dataset.py (renamed from src/text_recognizer/datasets/iam_paragraphs_dataset.py)0
-rw-r--r--text_recognizer/datasets/iam_preprocessor.py (renamed from src/text_recognizer/datasets/iam_preprocessor.py)0
-rw-r--r--text_recognizer/datasets/sentence_generator.py (renamed from src/text_recognizer/datasets/sentence_generator.py)0
-rw-r--r--text_recognizer/datasets/transforms.py (renamed from src/text_recognizer/datasets/transforms.py)0
-rw-r--r--text_recognizer/datasets/util.py (renamed from src/text_recognizer/datasets/util.py)0
-rw-r--r--text_recognizer/line_predictor.py (renamed from src/text_recognizer/line_predictor.py)0
-rw-r--r--text_recognizer/models/__init__.py (renamed from src/text_recognizer/models/__init__.py)0
-rw-r--r--text_recognizer/models/base.py (renamed from src/text_recognizer/models/base.py)0
-rw-r--r--text_recognizer/models/character_model.py (renamed from src/text_recognizer/models/character_model.py)0
-rw-r--r--text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/crnn_model.py)0
-rw-r--r--text_recognizer/models/ctc_transformer_model.py (renamed from src/text_recognizer/models/ctc_transformer_model.py)0
-rw-r--r--text_recognizer/models/segmentation_model.py (renamed from src/text_recognizer/models/segmentation_model.py)0
-rw-r--r--text_recognizer/models/transformer_model.py (renamed from src/text_recognizer/models/transformer_model.py)0
-rw-r--r--text_recognizer/models/vqvae_model.py (renamed from src/text_recognizer/models/vqvae_model.py)0
-rw-r--r--text_recognizer/networks/__init__.py (renamed from src/text_recognizer/networks/__init__.py)0
-rw-r--r--text_recognizer/networks/beam.py (renamed from src/text_recognizer/networks/beam.py)0
-rw-r--r--text_recognizer/networks/cnn.py (renamed from src/text_recognizer/networks/cnn.py)0
-rw-r--r--text_recognizer/networks/cnn_transformer.py (renamed from src/text_recognizer/networks/cnn_transformer.py)2
-rw-r--r--text_recognizer/networks/crnn.py (renamed from src/text_recognizer/networks/crnn.py)0
-rw-r--r--text_recognizer/networks/ctc.py (renamed from src/text_recognizer/networks/ctc.py)0
-rw-r--r--text_recognizer/networks/densenet.py (renamed from src/text_recognizer/networks/densenet.py)0
-rw-r--r--text_recognizer/networks/lenet.py (renamed from src/text_recognizer/networks/lenet.py)0
-rw-r--r--text_recognizer/networks/loss/__init__.py (renamed from src/text_recognizer/networks/loss/__init__.py)0
-rw-r--r--text_recognizer/networks/loss/loss.py (renamed from src/text_recognizer/networks/loss/loss.py)0
-rw-r--r--text_recognizer/networks/metrics.py (renamed from src/text_recognizer/networks/metrics.py)0
-rw-r--r--text_recognizer/networks/mlp.py (renamed from src/text_recognizer/networks/mlp.py)0
-rw-r--r--text_recognizer/networks/residual_network.py (renamed from src/text_recognizer/networks/residual_network.py)0
-rw-r--r--text_recognizer/networks/stn.py (renamed from src/text_recognizer/networks/stn.py)0
-rw-r--r--text_recognizer/networks/transducer/__init__.py (renamed from src/text_recognizer/networks/transducer/__init__.py)0
-rw-r--r--text_recognizer/networks/transducer/tds_conv.py (renamed from src/text_recognizer/networks/transducer/tds_conv.py)0
-rw-r--r--text_recognizer/networks/transducer/test.py (renamed from src/text_recognizer/networks/transducer/test.py)0
-rw-r--r--text_recognizer/networks/transducer/transducer.py (renamed from src/text_recognizer/networks/transducer/transducer.py)0
-rw-r--r--text_recognizer/networks/transformer/__init__.py (renamed from src/text_recognizer/networks/transformer/__init__.py)0
-rw-r--r--text_recognizer/networks/transformer/attention.py (renamed from src/text_recognizer/networks/transformer/attention.py)0
-rw-r--r--text_recognizer/networks/transformer/positional_encoding.py (renamed from src/text_recognizer/networks/transformer/positional_encoding.py)0
-rw-r--r--text_recognizer/networks/transformer/transformer.py (renamed from src/text_recognizer/networks/transformer/transformer.py)0
-rw-r--r--text_recognizer/networks/unet.py (renamed from src/text_recognizer/networks/unet.py)0
-rw-r--r--text_recognizer/networks/util.py (renamed from src/text_recognizer/networks/util.py)0
-rw-r--r--text_recognizer/networks/vit.py (renamed from src/text_recognizer/networks/vit.py)0
-rw-r--r--text_recognizer/networks/vq_transformer.py (renamed from src/text_recognizer/networks/vq_transformer.py)0
-rw-r--r--text_recognizer/networks/vqvae/__init__.py (renamed from src/text_recognizer/networks/vqvae/__init__.py)0
-rw-r--r--text_recognizer/networks/vqvae/decoder.py (renamed from src/text_recognizer/networks/vqvae/decoder.py)0
-rw-r--r--text_recognizer/networks/vqvae/encoder.py (renamed from src/text_recognizer/networks/vqvae/encoder.py)0
-rw-r--r--text_recognizer/networks/vqvae/vector_quantizer.py (renamed from src/text_recognizer/networks/vqvae/vector_quantizer.py)0
-rw-r--r--text_recognizer/networks/vqvae/vqvae.py (renamed from src/text_recognizer/networks/vqvae/vqvae.py)0
-rw-r--r--text_recognizer/networks/wide_resnet.py (renamed from src/text_recognizer/networks/wide_resnet.py)0
-rw-r--r--text_recognizer/paragraph_text_recognizer.py (renamed from src/text_recognizer/paragraph_text_recognizer.py)0
-rw-r--r--text_recognizer/tests/__init__.py (renamed from src/text_recognizer/tests/__init__.py)0
-rw-r--r--text_recognizer/tests/support/__init__.py (renamed from src/text_recognizer/tests/support/__init__.py)0
-rw-r--r--text_recognizer/tests/support/create_emnist_lines_support_files.py (renamed from src/text_recognizer/tests/support/create_emnist_lines_support_files.py)0
-rw-r--r--text_recognizer/tests/support/create_emnist_support_files.py (renamed from src/text_recognizer/tests/support/create_emnist_support_files.py)0
-rw-r--r--text_recognizer/tests/support/create_iam_lines_support_files.py (renamed from src/text_recognizer/tests/support/create_iam_lines_support_files.py)0
-rw-r--r--text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png (renamed from src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png)bin2301 -> 2301 bytes
-rw-r--r--text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png (renamed from src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png)bin5424 -> 5424 bytes
-rw-r--r--text_recognizer/tests/support/emnist_lines/they<eos>.png (renamed from src/text_recognizer/tests/support/emnist_lines/they<eos>.png)bin1391 -> 1391 bytes
-rw-r--r--text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png (renamed from src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png)bin5170 -> 5170 bytes
-rw-r--r--text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png (renamed from src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png)bin3617 -> 3617 bytes
-rw-r--r--text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png (renamed from src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png)bin3923 -> 3923 bytes
-rw-r--r--text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg (renamed from src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg)bin14890 -> 14890 bytes
-rw-r--r--text_recognizer/tests/test_character_predictor.py (renamed from src/text_recognizer/tests/test_character_predictor.py)0
-rw-r--r--text_recognizer/tests/test_line_predictor.py (renamed from src/text_recognizer/tests/test_line_predictor.py)0
-rw-r--r--text_recognizer/tests/test_paragraph_text_recognizer.py (renamed from src/text_recognizer/tests/test_paragraph_text_recognizer.py)0
-rw-r--r--text_recognizer/util.py (renamed from src/text_recognizer/util.py)0
-rw-r--r--text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt (renamed from src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt)0
-rw-r--r--text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt (renamed from src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt)0
-rw-r--r--text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt (renamed from src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt)0
-rw-r--r--text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt (renamed from src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt)bin8588813 -> 8588813 bytes
-rw-r--r--text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt (renamed from src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt)bin92335101 -> 92335101 bytes
-rw-r--r--text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt (renamed from src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt)bin21687018 -> 21687018 bytes
-rw-r--r--training/experiments/default_config_emnist.yml (renamed from src/training/experiments/default_config_emnist.yml)0
-rw-r--r--training/experiments/embedding_experiment.yml (renamed from src/training/experiments/embedding_experiment.yml)0
-rw-r--r--training/experiments/sample_experiment.yml (renamed from src/training/experiments/sample_experiment.yml)0
-rw-r--r--training/gpu_manager.py (renamed from src/training/gpu_manager.py)0
-rw-r--r--training/prepare_experiments.py (renamed from src/training/prepare_experiments.py)0
-rw-r--r--training/run_experiment.py (renamed from src/training/run_experiment.py)0
-rw-r--r--training/run_sweep.py (renamed from src/training/run_sweep.py)0
-rw-r--r--training/sweep_emnist.yml (renamed from src/training/sweep_emnist.yml)0
-rw-r--r--training/sweep_emnist_resnet.yml (renamed from src/training/sweep_emnist_resnet.yml)0
-rw-r--r--training/trainer/__init__.py (renamed from src/training/trainer/__init__.py)0
-rw-r--r--training/trainer/callbacks/__init__.py (renamed from src/training/trainer/callbacks/__init__.py)0
-rw-r--r--training/trainer/callbacks/base.py (renamed from src/training/trainer/callbacks/base.py)0
-rw-r--r--training/trainer/callbacks/checkpoint.py (renamed from src/training/trainer/callbacks/checkpoint.py)0
-rw-r--r--training/trainer/callbacks/early_stopping.py (renamed from src/training/trainer/callbacks/early_stopping.py)0
-rw-r--r--training/trainer/callbacks/lr_schedulers.py (renamed from src/training/trainer/callbacks/lr_schedulers.py)0
-rw-r--r--training/trainer/callbacks/progress_bar.py (renamed from src/training/trainer/callbacks/progress_bar.py)0
-rw-r--r--training/trainer/callbacks/wandb_callbacks.py (renamed from src/training/trainer/callbacks/wandb_callbacks.py)0
-rw-r--r--training/trainer/train.py (renamed from src/training/trainer/train.py)0
-rw-r--r--training/trainer/util.py (renamed from src/training/trainer/util.py)0
-rw-r--r--wandb/settings (renamed from src/wandb/settings)0
135 files changed, 1587 insertions, 1182 deletions
diff --git a/src/.gitattributes b/.gitattributes
index eebe826..eebe826 100644
--- a/src/.gitattributes
+++ b/.gitattributes
diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb
new file mode 100644
index 0000000..becd918
--- /dev/null
+++ b/notebooks/00-testing-stuff-out.ipynb
@@ -0,0 +1,1469 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%load_ext autoreload\n",
+ "%autoreload 2\n",
+ "\n",
+ "%matplotlib inline\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "from PIL import Image\n",
+ "import torch.nn.functional as F\n",
+ "import torch\n",
+ "from torch import nn\n",
+ "from torchsummary import summary\n",
+ "from importlib.util import find_spec\n",
+ "if find_spec(\"text_recognizer\") is None:\n",
+ " import sys\n",
+ " sys.path.append('..')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import CNN, TDS2d"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tds2d = TDS2d(**{\n",
+ " \"depth\" : 4,\n",
+ " \"tds_groups\" : [\n",
+ " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
+ " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n",
+ " ],\n",
+ " \"kernel_size\" : [5, 7],\n",
+ " \"dropout_rate\" : 0.1\n",
+ " }, input_dim=32, output_dim=128)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TDS2d(\n",
+ " (tds): Sequential(\n",
+ " (0): Conv2d(1, 16, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (4): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (5): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (6): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=16, out_features=16, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (7): Conv2d(16, 128, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n",
+ " (8): ReLU(inplace=True)\n",
+ " (9): Dropout(p=0.1, inplace=False)\n",
+ " (10): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (11): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (12): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (13): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=128, out_features=128, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (14): Conv2d(128, 256, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n",
+ " (15): ReLU(inplace=True)\n",
+ " (16): Dropout(p=0.1, inplace=False)\n",
+ " (17): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (18): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (19): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (20): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=256, out_features=256, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (21): Conv2d(256, 512, kernel_size=[5, 7], stride=[2, 1], padding=(2, 3))\n",
+ " (22): ReLU(inplace=True)\n",
+ " (23): Dropout(p=0.1, inplace=False)\n",
+ " (24): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (25): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (26): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " (27): TDSBlock2d(\n",
+ " (conv): Sequential(\n",
+ " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (mlp): Sequential(\n",
+ " (0): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (1): ReLU(inplace=True)\n",
+ " (2): Dropout(p=0.1, inplace=False)\n",
+ " (3): Linear(in_features=512, out_features=512, bias=True)\n",
+ " (4): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (instance_norm): ModuleList(\n",
+ " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (fc): Linear(in_features=1024, out_features=128, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tds2d"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Sequential: 1-1 [-1, 512, 2, 119] --\n",
+ "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n",
+ "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n",
+ "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n",
+ "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n",
+ "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n",
+ "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n",
+ "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n",
+ "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n",
+ "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n",
+ "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n",
+ "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n",
+ "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n",
+ "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n",
+ "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n",
+ "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n",
+ "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n",
+ "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n",
+ "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n",
+ "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n",
+ "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n",
+ "├─Linear: 1-2 [-1, 119, 128] 131,200\n",
+ "===============================================================================================\n",
+ "Total params: 10,272,252\n",
+ "Trainable params: 10,272,252\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 5.00\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 73.21\n",
+ "Params size (MB): 39.19\n",
+ "Estimated Total Size (MB): 112.50\n",
+ "===============================================================================================\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "===============================================================================================\n",
+ "Layer (type:depth-idx) Output Shape Param #\n",
+ "===============================================================================================\n",
+ "├─Sequential: 1-1 [-1, 512, 2, 119] --\n",
+ "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n",
+ "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n",
+ "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n",
+ "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n",
+ "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n",
+ "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n",
+ "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n",
+ "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n",
+ "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n",
+ "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n",
+ "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n",
+ "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n",
+ "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n",
+ "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n",
+ "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n",
+ "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n",
+ "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n",
+ "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n",
+ "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n",
+ "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n",
+ "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n",
+ "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n",
+ "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n",
+ "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n",
+ "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n",
+ "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n",
+ "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n",
+ "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n",
+ "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n",
+ "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n",
+ "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n",
+ "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n",
+ "├─Linear: 1-2 [-1, 119, 128] 131,200\n",
+ "===============================================================================================\n",
+ "Total params: 10,272,252\n",
+ "Trainable params: 10,272,252\n",
+ "Non-trainable params: 0\n",
+ "Total mult-adds (G): 5.00\n",
+ "===============================================================================================\n",
+ "Input size (MB): 0.10\n",
+ "Forward/backward pass size (MB): 73.21\n",
+ "Params size (MB): 39.19\n",
+ "Estimated Total Size (MB): 112.50\n",
+ "==============================================================================================="
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(2,1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "torch.Size([2, 119, 128])"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tds2d(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cnn = CNN().cuda()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "i = nn.Sequential(nn.Conv2d(1,1,1,1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "nn.Sequential(i,i)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cnn(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(2, 1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x, l = vqvae(t)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "5 * 59 / 10"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "x.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "up = nn.Upsample([4, 59])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "up(tt).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tt.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class GEGLU(nn.Module):\n",
+ " def __init__(self, dim_in, dim_out):\n",
+ " super().__init__()\n",
+ " self.proj = nn.Linear(dim_in, dim_out * 2)\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x, gate = self.proj(x).chunk(2, dim = -1)\n",
+ " return x * F.gelu(gate)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e = GEGLU(256, 2048)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e(t).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "emb = nn.Embedding(56, 256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "with torch.no_grad():\n",
+ " e = emb(torch.Tensor([55]).long())"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import repeat"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ee = repeat(e, \"() n -> b n\", b=16)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "emb.device"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ee"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ee.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.randn(16, 10, 256)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "e.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import WideResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "wr = WideResidualNetwork(\n",
+ " in_channels= 1,\n",
+ " num_classes= 80,\n",
+ " in_planes=64,\n",
+ " depth=10,\n",
+ " num_layers=4,\n",
+ " width_factor=2,\n",
+ " num_stages=[64, 128, 256, 256],\n",
+ " dropout_rate= 0.1,\n",
+ " activation= \"SELU\",\n",
+ " use_decoder= False,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torchsummary import summary"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ " backbone = nn.Sequential(\n",
+ " *list(wr.children())[:][:]\n",
+ " )\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a = torch.rand(1, 1, 28, 952)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = wr(a)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from einops import rearrange"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b = rearrange(b, \"b c h w -> b w c h\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "c = nn.AdaptiveAvgPool2d((None, 1))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d = c(b)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "d.squeeze(3).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "b.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from torch import nn"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "32 + 64"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "3 * 112"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "col_embed = nn.Parameter(torch.rand(1000, 256 // 2))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "W, H = 196, 4"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ " torch.cat(\n",
+ " [\n",
+ " col_embed[:W].unsqueeze(0).repeat(H, 1, 1),\n",
+ " col_embed[:H].unsqueeze(1).repeat(1, W, 1),\n",
+ " ],\n",
+ " dim=-1,\n",
+ " ).unsqueeze(0).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "4 * 196"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target = torch.tensor([1,1,12,1,1,1,1,1,9,9,9,9,9,9])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "torch.nonzero(target == 9, as_tuple=False)[0].item()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target[:9]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "np.inf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.transformer.positional_encoding import PositionalEncoding"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(15, 5))\n",
+ "pe = PositionalEncoding(20, 0)\n",
+ "y = pe.forward(torch.zeros(1, 100, 20))\n",
+ "plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())\n",
+ "plt.legend([\"dim %d\"%p for p in [4,5,6,7]])\n",
+ "None"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks.densenet import DenseNet,_DenseLayer,_DenseBlock"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "216 / 8"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ " backbone = nn.Sequential(\n",
+ " *list(dnet.children())[:][:-4]\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "backbone"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.networks import WideResidualNetwork"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "w = WideResidualNetwork(\n",
+ " in_channels = 1,\n",
+ " in_planes = 32,\n",
+ " num_classes = 80,\n",
+ " depth = 10,\n",
+ " width_factor = 1,\n",
+ " dropout_rate = 0.0,\n",
+ " num_layers = 5,\n",
+ " activation = \"relu\",\n",
+ " use_decoder = False,)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "summary(w, (1, 28, 952), device=\"cpu\", depth=2)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sz= 5"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask = torch.triu(torch.ones(sz, sz), 1)\n",
+ "mask = mask.masked_fill(mask==1, float('-inf'))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\n",
+ "h = torch.rand(1, 256, 10, 10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "h.flatten(2).permute(2, 0, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "h.flatten(2).permute(2, 0, 1).shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred = torch.Tensor([1,21,2,45,31, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n",
+ "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask = (target != 79)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred * mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target * mask"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from text_recognizer.models.metrics import accuracy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pad_indcies = torch.nonzero(target == 79, as_tuple=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target.shape[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t2 = torch.arange(10, target.shape[0] + 1, 10)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "t2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "for start, stop in zip(t1, t2):\n",
+ " pred[start+1:stop] = 79"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [],
+ "source": [
+ "pad_indcies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred[pad_indcies:pad_indcies] = 79"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "pred.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "target.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "accuracy(pred, target)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc = (pred == target).sum().float() / target.shape[0]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "acc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.1"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/src/notebooks/01-look-at-emnist.ipynb b/notebooks/01-look-at-emnist.ipynb
index b70ce12..b70ce12 100644
--- a/src/notebooks/01-look-at-emnist.ipynb
+++ b/notebooks/01-look-at-emnist.ipynb
diff --git a/src/notebooks/02a-sentence-generator.ipynb b/notebooks/02a-sentence-generator.ipynb
index 99aa56a..99aa56a 100644
--- a/src/notebooks/02a-sentence-generator.ipynb
+++ b/notebooks/02a-sentence-generator.ipynb
diff --git a/src/notebooks/02b-emnist-lines-dataset.ipynb b/notebooks/02b-emnist-lines-dataset.ipynb
index f82342b..f82342b 100644
--- a/src/notebooks/02b-emnist-lines-dataset.ipynb
+++ b/notebooks/02b-emnist-lines-dataset.ipynb
diff --git a/src/notebooks/02c-image-patches.ipynb b/notebooks/02c-image-patches.ipynb
index fedea91..fedea91 100644
--- a/src/notebooks/02c-image-patches.ipynb
+++ b/notebooks/02c-image-patches.ipynb
diff --git a/src/notebooks/03a-line-prediction.ipynb b/notebooks/03a-line-prediction.ipynb
index 13f4ff1..13f4ff1 100644
--- a/src/notebooks/03a-line-prediction.ipynb
+++ b/notebooks/03a-line-prediction.ipynb
diff --git a/src/notebooks/04a-look-at-iam-lines.ipynb b/notebooks/04a-look-at-iam-lines.ipynb
index de59a85..de59a85 100644
--- a/src/notebooks/04a-look-at-iam-lines.ipynb
+++ b/notebooks/04a-look-at-iam-lines.ipynb
diff --git a/src/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb
index 5662eb1..5662eb1 100644
--- a/src/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb
+++ b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb
diff --git a/src/notebooks/04b-look-at-iam-paragraphs.ipynb b/notebooks/04b-look-at-iam-paragraphs.ipynb
index dc0aef6..dc0aef6 100644
--- a/src/notebooks/04b-look-at-iam-paragraphs.ipynb
+++ b/notebooks/04b-look-at-iam-paragraphs.ipynb
diff --git a/src/notebooks/05-sanity-check-multihead-attention.ipynb b/notebooks/05-sanity-check-multihead-attention.ipynb
index 54f0432..54f0432 100644
--- a/src/notebooks/05-sanity-check-multihead-attention.ipynb
+++ b/notebooks/05-sanity-check-multihead-attention.ipynb
diff --git a/src/notebooks/05a-UNet.ipynb b/notebooks/05a-UNet.ipynb
index 77d895d..77d895d 100644
--- a/src/notebooks/05a-UNet.ipynb
+++ b/notebooks/05a-UNet.ipynb
diff --git a/src/notebooks/05a-test-end-to-end-model.ipynb b/notebooks/05a-test-end-to-end-model.ipynb
index 7723b12..7723b12 100644
--- a/src/notebooks/05a-test-end-to-end-model.ipynb
+++ b/notebooks/05a-test-end-to-end-model.ipynb
diff --git a/src/notebooks/06-try-transformer-model-predictions.ipynb b/notebooks/06-try-transformer-model-predictions.ipynb
index d39e111..d39e111 100644
--- a/src/notebooks/06-try-transformer-model-predictions.ipynb
+++ b/notebooks/06-try-transformer-model-predictions.ipynb
diff --git a/src/notebooks/07-look-at-lexicon.ipynb b/notebooks/07-look-at-lexicon.ipynb
index b7a5a0e..b7a5a0e 100644
--- a/src/notebooks/07-look-at-lexicon.ipynb
+++ b/notebooks/07-look-at-lexicon.ipynb
diff --git a/src/notebooks/07-try-gtn.ipynb b/notebooks/07-try-gtn.ipynb
index 4ef444b..4ef444b 100644
--- a/src/notebooks/07-try-gtn.ipynb
+++ b/notebooks/07-try-gtn.ipynb
diff --git a/src/notebooks/Untitled.ipynb b/notebooks/Untitled.ipynb
index 841a37d..841a37d 100644
--- a/src/notebooks/Untitled.ipynb
+++ b/notebooks/Untitled.ipynb
diff --git a/src/notebooks/g1.png b/notebooks/g1.png
index 09dd49e..09dd49e 100644
--- a/src/notebooks/g1.png
+++ b/notebooks/g1.png
Binary files differ
diff --git a/src/notebooks/g2.png b/notebooks/g2.png
index a3cf21e..a3cf21e 100644
--- a/src/notebooks/g2.png
+++ b/notebooks/g2.png
Binary files differ
diff --git a/src/notebooks/intersect.png b/notebooks/intersect.png
index 63b7f2f..63b7f2f 100644
--- a/src/notebooks/intersect.png
+++ b/notebooks/intersect.png
Binary files differ
diff --git a/src/notebooks/intersection.pdf b/notebooks/intersection.pdf
index c425a9f..c425a9f 100644
--- a/src/notebooks/intersection.pdf
+++ b/notebooks/intersection.pdf
Binary files differ
diff --git a/noxfile.py b/noxfile.py
index 60c3923..098a551 100644
--- a/noxfile.py
+++ b/noxfile.py
@@ -40,7 +40,7 @@ def install_with_constraints(session: Session, *args: str, **kwargs: Any) -> Non
session.install(f"--constraint={requirements.name}", *args, **kwargs)
-@nox.session(python="3.8")
+@nox.session(python="3.9")
def black(session: Session) -> None:
"""Run black code formatter."""
args = session.posargs or locations
@@ -48,7 +48,7 @@ def black(session: Session) -> None:
session.run("black", *args)
-@nox.session(python=["3.8"])
+@nox.session(python=["3.9"])
def lint(session: Session) -> None:
"""Lint using flake8."""
args = session.posargs or locations
@@ -66,7 +66,7 @@ def lint(session: Session) -> None:
session.run("flake8", *args)
-@nox.session(python="3.8")
+@nox.session(python="3.9")
def safety(session: Session) -> None:
"""Scan dependencies for insecure packages."""
with tempfile.NamedTemporaryFile() as requirements:
@@ -83,7 +83,7 @@ def safety(session: Session) -> None:
session.run("safety", "check", f"--file={requirements.name}", "--full-report")
-@nox.session(python=["3.8"])
+@nox.session(python=["3.9"])
def mypy(session: Session) -> None:
"""Type-check using mypy."""
args = session.posargs or locations
@@ -91,7 +91,7 @@ def mypy(session: Session) -> None:
session.run("mypy", *args)
-@nox.session(python="3.8")
+@nox.session(python="3.9")
def pytype(session: Session) -> None:
"""Type-check using pytype."""
args = session.posargs or ["--disable=import-error", *locations]
@@ -99,7 +99,7 @@ def pytype(session: Session) -> None:
session.run("pytype", *args)
-@nox.session(python=["3.8"])
+@nox.session(python=["3.9"])
def tests(session: Session) -> None:
"""Run the test suite."""
args = session.posargs or ["--cov", "-m", "not e2e"]
@@ -110,7 +110,7 @@ def tests(session: Session) -> None:
session.run("pytest", *args)
-@nox.session(python=["3.8"])
+@nox.session(python=["3.9"])
def typeguard(session: Session) -> None:
"""Runtime type checking using Typeguard."""
args = session.posargs or ["-m", "not e2e"]
@@ -119,7 +119,7 @@ def typeguard(session: Session) -> None:
session.run("pytest", f"--typeguard-packages={package}", *args)
-@nox.session(python=["3.8"])
+@nox.session(python=["3.9"])
def xdoctest(session: Session) -> None:
"""Run examples with xdoctest."""
args = session.posargs or ["all"]
@@ -128,7 +128,7 @@ def xdoctest(session: Session) -> None:
session.run("python", "-m", "xdoctest", package, *args)
-@nox.session(python="3.8")
+@nox.session(python="3.9")
def coverage(session: Session) -> None:
"""Upload coverage data."""
install_with_constraints(session, "coverage[toml]", "codecov")
@@ -136,7 +136,7 @@ def coverage(session: Session) -> None:
session.run("codecov", *session.posargs)
-@nox.session(python="3.8")
+@nox.session(python="3.9")
def docs(session: Session) -> None:
"""Build the documentation."""
session.run("poetry", "install", "--no-dev", external=True)
diff --git a/poetry.lock b/poetry.lock
index 72da168..78f086e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -244,14 +244,6 @@ optional = false
python-versions = ">=3.6,<4.0"
[[package]]
-name = "dataclasses"
-version = "0.6"
-description = "A backport of the dataclasses module for Python 3.6"
-category = "main"
-optional = false
-python-versions = "*"
-
-[[package]]
name = "decorator"
version = "4.4.2"
description = "Decorators for Humans"
@@ -432,14 +424,6 @@ python-versions = "*"
flake8 = "*"
[[package]]
-name = "future"
-version = "0.18.2"
-description = "Clean single-source support for Python 3 and 2"
-category = "main"
-optional = false
-python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
-
-[[package]]
name = "gitdb"
version = "4.0.5"
description = "Git Object Database"
@@ -501,15 +485,17 @@ python-versions = ">=3.5"
[[package]]
name = "h5py"
-version = "2.10.0"
+version = "3.2.1"
description = "Read and write HDF5 files from Python"
category = "main"
optional = false
-python-versions = "*"
+python-versions = ">=3.7"
[package.dependencies]
-numpy = ">=1.7"
-six = "*"
+numpy = [
+ {version = ">=1.17.5", markers = "python_version == \"3.8\""},
+ {version = ">=1.19.3", markers = "python_version >= \"3.9\""},
+]
[[package]]
name = "idna"
@@ -995,11 +981,11 @@ test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose-
[[package]]
name = "numpy"
-version = "1.19.4"
+version = "1.20.1"
description = "NumPy is the fundamental package for array computing with Python."
category = "main"
optional = false
-python-versions = ">=3.6"
+python-versions = ">=3.7"
[[package]]
name = "nvidia-ml-py3"
@@ -1029,6 +1015,9 @@ category = "main"
optional = false
python-versions = ">=3.6"
+[package.dependencies]
+numpy = ">=1.19.3"
+
[[package]]
name = "packaging"
version = "20.4"
@@ -1782,15 +1771,13 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*"
[[package]]
name = "torch"
-version = "1.7.0"
+version = "1.7.1"
description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration"
category = "main"
optional = false
-python-versions = ">=3.6.1"
+python-versions = ">=3.6.2"
[package.dependencies]
-dataclasses = "*"
-future = "*"
numpy = "*"
typing-extensions = "*"
@@ -1804,7 +1791,7 @@ python-versions = ">=3.5"
[[package]]
name = "torchvision"
-version = "0.8.1"
+version = "0.8.2"
description = "image and video datasets and models for torch deep learning"
category = "main"
optional = false
@@ -1813,7 +1800,7 @@ python-versions = "*"
[package.dependencies]
numpy = "*"
pillow = ">=4.1.1"
-torch = "1.7.0"
+torch = "1.7.1"
[package.extras]
scipy = ["scipy"]
@@ -2006,7 +1993,7 @@ tests = ["pytest", "pytest-cov", "codecov", "scikit-build", "cmake", "ninja", "p
[metadata]
lock-version = "1.1"
python-versions = "^3.8"
-content-hash = "1f194d7de179e9676ef1f8e51b83ff15c001627803008ef8225e8e14ab3acab0"
+content-hash = "c87742a388e1277e84313b4c0ff75681d754c8328db2c488c0aba2a4dafc6a64"
[metadata.files]
alabaster = [
@@ -2038,6 +2025,8 @@ argon2-cffi = [
{file = "argon2_cffi-20.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:6678bb047373f52bcff02db8afab0d2a77d83bde61cfecea7c5c62e2335cb203"},
{file = "argon2_cffi-20.1.0-cp38-cp38-win32.whl", hash = "sha256:77e909cc756ef81d6abb60524d259d959bab384832f0c651ed7dcb6e5ccdbb78"},
{file = "argon2_cffi-20.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:9dfd5197852530294ecb5795c97a823839258dfd5eb9420233c7cfedec2058f2"},
+ {file = "argon2_cffi-20.1.0-cp39-cp39-win32.whl", hash = "sha256:e2db6e85c057c16d0bd3b4d2b04f270a7467c147381e8fd73cbbe5bc719832be"},
+ {file = "argon2_cffi-20.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8a84934bd818e14a17943de8099d41160da4a336bcc699bb4c394bbb9b94bd32"},
]
async-generator = [
{file = "async_generator-1.10-py3-none-any.whl", hash = "sha256:01c7bf666359b4967d2cda0000cc2e4af16a0ae098cbffcb8472fb9e8ad6585b"},
@@ -2182,10 +2171,6 @@ darglint = [
{file = "darglint-1.5.6-py3-none-any.whl", hash = "sha256:6fcef385e646c4da9ea6fc547e28c77a33ae0cba4806b8585ae18a490a797e82"},
{file = "darglint-1.5.6.tar.gz", hash = "sha256:98acb4064bae73ec02146cb123dd3c930bd5272e562ad4d19c59857443632dd1"},
]
-dataclasses = [
- {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"},
- {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"},
-]
decorator = [
{file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"},
{file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"},
@@ -2248,9 +2233,6 @@ flake8-polyfill = [
{file = "flake8-polyfill-1.0.2.tar.gz", hash = "sha256:e44b087597f6da52ec6393a709e7108b2905317d0c0b744cdca6208e670d8eda"},
{file = "flake8_polyfill-1.0.2-py2.py3-none-any.whl", hash = "sha256:12be6a34ee3ab795b19ca73505e7b55826d5f6ad7230d31b18e106400169b9e9"},
]
-future = [
- {file = "future-0.18.2.tar.gz", hash = "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"},
-]
gitdb = [
{file = "gitdb-4.0.5-py3-none-any.whl", hash = "sha256:91f36bfb1ab7949b3b40e23736db18231bf7593edada2ba5c3a174a7b23657ac"},
{file = "gitdb-4.0.5.tar.gz", hash = "sha256:c9e1f2d0db7ddb9a704c2a0217be31214e91a4fe1dea1efad19ae42ba0c285c9"},
@@ -2270,35 +2252,16 @@ gtn = [
{file = "gtn-0.0.0.tar.gz", hash = "sha256:72fece9ca51df161c1274e570d6f5f933e76f4cac9d8d6dd543a3fe0383f7268"},
]
h5py = [
- {file = "h5py-2.10.0-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ecf4d0b56ee394a0984de15bceeb97cbe1fe485f1ac205121293fc44dcf3f31f"},
- {file = "h5py-2.10.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:86868dc07b9cc8cb7627372a2e6636cdc7a53b7e2854ad020c9e9d8a4d3fd0f5"},
- {file = "h5py-2.10.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:aac4b57097ac29089f179bbc2a6e14102dd210618e94d77ee4831c65f82f17c0"},
- {file = "h5py-2.10.0-cp27-cp27m-win32.whl", hash = "sha256:7be5754a159236e95bd196419485343e2b5875e806fe68919e087b6351f40a70"},
- {file = "h5py-2.10.0-cp27-cp27m-win_amd64.whl", hash = "sha256:13c87efa24768a5e24e360a40e0bc4c49bcb7ce1bb13a3a7f9902cec302ccd36"},
- {file = "h5py-2.10.0-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:79b23f47c6524d61f899254f5cd5e486e19868f1823298bc0c29d345c2447172"},
- {file = "h5py-2.10.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:cbf28ae4b5af0f05aa6e7551cee304f1d317dbed1eb7ac1d827cee2f1ef97a99"},
- {file = "h5py-2.10.0-cp34-cp34m-manylinux1_i686.whl", hash = "sha256:c0d4b04bbf96c47b6d360cd06939e72def512b20a18a8547fa4af810258355d5"},
- {file = "h5py-2.10.0-cp34-cp34m-manylinux1_x86_64.whl", hash = "sha256:549ad124df27c056b2e255ea1c44d30fb7a17d17676d03096ad5cd85edb32dc1"},
- {file = "h5py-2.10.0-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:a5f82cd4938ff8761d9760af3274acf55afc3c91c649c50ab18fcff5510a14a5"},
- {file = "h5py-2.10.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:3dad1730b6470fad853ef56d755d06bb916ee68a3d8272b3bab0c1ddf83bb99e"},
- {file = "h5py-2.10.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:063947eaed5f271679ed4ffa36bb96f57bc14f44dd4336a827d9a02702e6ce6b"},
- {file = "h5py-2.10.0-cp35-cp35m-win32.whl", hash = "sha256:c54a2c0dd4957776ace7f95879d81582298c5daf89e77fb8bee7378f132951de"},
- {file = "h5py-2.10.0-cp35-cp35m-win_amd64.whl", hash = "sha256:6998be619c695910cb0effe5eb15d3a511d3d1a5d217d4bd0bebad1151ec2262"},
- {file = "h5py-2.10.0-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:ff7d241f866b718e4584fa95f520cb19405220c501bd3a53ee11871ba5166ea2"},
- {file = "h5py-2.10.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:54817b696e87eb9e403e42643305f142cd8b940fe9b3b490bbf98c3b8a894cf4"},
- {file = "h5py-2.10.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d3c59549f90a891691991c17f8e58c8544060fdf3ccdea267100fa5f561ff62f"},
- {file = "h5py-2.10.0-cp36-cp36m-win32.whl", hash = "sha256:d7ae7a0576b06cb8e8a1c265a8bc4b73d05fdee6429bffc9a26a6eb531e79d72"},
- {file = "h5py-2.10.0-cp36-cp36m-win_amd64.whl", hash = "sha256:bffbc48331b4a801d2f4b7dac8a72609f0b10e6e516e5c480a3e3241e091c878"},
- {file = "h5py-2.10.0-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:51ae56894c6c93159086ffa2c94b5b3388c0400548ab26555c143e7cfa05b8e5"},
- {file = "h5py-2.10.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16ead3c57141101e3296ebeed79c9c143c32bdd0e82a61a2fc67e8e6d493e9d1"},
- {file = "h5py-2.10.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:f0e25bb91e7a02efccb50aba6591d3fe2c725479e34769802fcdd4076abfa917"},
- {file = "h5py-2.10.0-cp37-cp37m-win32.whl", hash = "sha256:f23951a53d18398ef1344c186fb04b26163ca6ce449ebd23404b153fd111ded9"},
- {file = "h5py-2.10.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8bb1d2de101f39743f91512a9750fb6c351c032e5cd3204b4487383e34da7f75"},
- {file = "h5py-2.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64f74da4a1dd0d2042e7d04cf8294e04ddad686f8eba9bb79e517ae582f6668d"},
- {file = "h5py-2.10.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:d35f7a3a6cefec82bfdad2785e78359a0e6a5fbb3f605dd5623ce88082ccd681"},
- {file = "h5py-2.10.0-cp38-cp38-win32.whl", hash = "sha256:6ef7ab1089e3ef53ca099038f3c0a94d03e3560e6aff0e9d6c64c55fb13fc681"},
- {file = "h5py-2.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:769e141512b54dee14ec76ed354fcacfc7d97fea5a7646b709f7400cf1838630"},
- {file = "h5py-2.10.0.tar.gz", hash = "sha256:84412798925dc870ffd7107f045d7659e60f5d46d1c70c700375248bf6bf512d"},
+ {file = "h5py-3.2.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6766104ed13ff40b3b7bfd49f13fced5274103ee9af53667e7a97c5236b14741"},
+ {file = "h5py-3.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:4160cb0d35a83c6fb9f1cad65e826dfaeb044e001549ea78003573fb6bee4042"},
+ {file = "h5py-3.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:fdabe99139a9c5e1a416b7ed38c89505f8501b376d54496e1bb737cb33df61cf"},
+ {file = "h5py-3.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d8467fa56356ad2efad2b5986326e71d4d74505de6f6c7bb46dbba09b37459ac"},
+ {file = "h5py-3.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:a6632ac11167bbad1a8fc5c82508b97ab8c12bdfe4b659254b6f7f63d3c76744"},
+ {file = "h5py-3.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:90ee8a00aca5c4e0bbd821c1f6118cb9a814c15dcfdb03572c615a4431166480"},
+ {file = "h5py-3.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25294f2690c4813475f566663a21ef1c1b11ef892b26d46454bf0a59e507d5aa"},
+ {file = "h5py-3.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d791b710d3e54c4d2c32cb881b183db5674ceb03bf6a0c1f3fb3cf50d8997e0a"},
+ {file = "h5py-3.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c5b5f18c96fb63399280a724734fd91e1781c6b60e385e439ad8e654a294ba4"},
+ {file = "h5py-3.2.1.tar.gz", hash = "sha256:89474be911bfcdb34cbf0d98b8ec48b578c27a89fdb1ae4ee7513f1ef8d9249e"},
]
idna = [
{file = "idna-2.10-py2.py3-none-any.whl", hash = "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"},
@@ -2426,20 +2389,39 @@ markupsafe = [
{file = "MarkupSafe-1.1.1-cp35-cp35m-win32.whl", hash = "sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1"},
{file = "MarkupSafe-1.1.1-cp35-cp35m-win_amd64.whl", hash = "sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d"},
{file = "MarkupSafe-1.1.1-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff"},
+ {file = "MarkupSafe-1.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d53bc011414228441014aa71dbec320c66468c1030aae3a6e29778a3382d96e5"},
{file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473"},
{file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e"},
+ {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:3b8a6499709d29c2e2399569d96719a1b21dcd94410a586a18526b143ec8470f"},
+ {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:84dee80c15f1b560d55bcfe6d47b27d070b4681c699c572af2e3c7cc90a3b8e0"},
+ {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:b1dba4527182c95a0db8b6060cc98ac49b9e2f5e64320e2b56e47cb2831978c7"},
{file = "MarkupSafe-1.1.1-cp36-cp36m-win32.whl", hash = "sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66"},
{file = "MarkupSafe-1.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d"},
+ {file = "MarkupSafe-1.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bf5aa3cbcfdf57fa2ee9cd1822c862ef23037f5c832ad09cfea57fa846dec193"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"},
+ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:6fffc775d90dcc9aed1b89219549b329a9250d918fd0b8fa8d93d154918422e1"},
+ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:a6a744282b7718a2a62d2ed9d993cad6f5f585605ad352c11de459f4108df0a1"},
+ {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:195d7d2c4fbb0ee8139a6cf67194f3973a6b3042d742ebe0a9ed36d8b6f0c07f"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"},
{file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"},
{file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"},
{file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"},
+ {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:acf08ac40292838b3cbbb06cfe9b2cb9ec78fce8baca31ddb87aaac2e2dc3bc2"},
+ {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:d9be0ba6c527163cbed5e0857c451fcd092ce83947944d6c14bc95441203f032"},
+ {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:caabedc8323f1e93231b52fc32bdcde6db817623d33e100708d9a68e1f53b26b"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"},
{file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d73a845f227b0bfe8a7455ee623525ee656a9e2e749e4742706d80a6065d5e2c"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:98bae9582248d6cf62321dcb52aaf5d9adf0bad3b40582925ef7c7f0ed85fceb"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:2beec1e0de6924ea551859edb9e7679da6e4870d32cb766240ce17e0a0ba2014"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:7fed13866cf14bba33e7176717346713881f56d9d2bcebab207f7a036f41b850"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:6f1e273a344928347c1290119b493a1f0303c52f5a5eae5f16d74f48c15d4a85"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:feb7b34d6325451ef96bc0e36e1a6c0c1c64bc1fbec4b854f4529e51887b1621"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-win32.whl", hash = "sha256:22c178a091fc6630d0d045bdb5992d2dfe14e3259760e713c490da5323866c39"},
+ {file = "MarkupSafe-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7d644ddb4dbd407d31ffb699f1d140bc35478da613b441c582aeb7c43838dd8"},
{file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"},
]
marshmallow = [
@@ -2529,40 +2511,30 @@ notebook = [
{file = "notebook-6.1.5.tar.gz", hash = "sha256:3db37ae834c5f3b6378381229d0e5dfcbfb558d08c8ce646b1ad355147f5e91d"},
]
numpy = [
- {file = "numpy-1.19.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e9b30d4bd69498fc0c3fe9db5f62fffbb06b8eb9321f92cc970f2969be5e3949"},
- {file = "numpy-1.19.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fedbd128668ead37f33917820b704784aff695e0019309ad446a6d0b065b57e4"},
- {file = "numpy-1.19.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:8ece138c3a16db8c1ad38f52eb32be6086cc72f403150a79336eb2045723a1ad"},
- {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:64324f64f90a9e4ef732be0928be853eee378fd6a01be21a0a8469c4f2682c83"},
- {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:ad6f2ff5b1989a4899bf89800a671d71b1612e5ff40866d1f4d8bcf48d4e5764"},
- {file = "numpy-1.19.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:d6c7bb82883680e168b55b49c70af29b84b84abb161cbac2800e8fcb6f2109b6"},
- {file = "numpy-1.19.4-cp36-cp36m-win32.whl", hash = "sha256:13d166f77d6dc02c0a73c1101dd87fdf01339febec1030bd810dcd53fff3b0f1"},
- {file = "numpy-1.19.4-cp36-cp36m-win_amd64.whl", hash = "sha256:448ebb1b3bf64c0267d6b09a7cba26b5ae61b6d2dbabff7c91b660c7eccf2bdb"},
- {file = "numpy-1.19.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:27d3f3b9e3406579a8af3a9f262f5339005dd25e0ecf3cf1559ff8a49ed5cbf2"},
- {file = "numpy-1.19.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16c1b388cc31a9baa06d91a19366fb99ddbe1c7b205293ed072211ee5bac1ed2"},
- {file = "numpy-1.19.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e5b6ed0f0b42317050c88022349d994fe72bfe35f5908617512cd8c8ef9da2a9"},
- {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:18bed2bcb39e3f758296584337966e68d2d5ba6aab7e038688ad53c8f889f757"},
- {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:fe45becb4c2f72a0907c1d0246ea6449fe7a9e2293bb0e11c4e9a32bb0930a15"},
- {file = "numpy-1.19.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:6d7593a705d662be5bfe24111af14763016765f43cb6923ed86223f965f52387"},
- {file = "numpy-1.19.4-cp37-cp37m-win32.whl", hash = "sha256:6ae6c680f3ebf1cf7ad1d7748868b39d9f900836df774c453c11c5440bc15b36"},
- {file = "numpy-1.19.4-cp37-cp37m-win_amd64.whl", hash = "sha256:9eeb7d1d04b117ac0d38719915ae169aa6b61fca227b0b7d198d43728f0c879c"},
- {file = "numpy-1.19.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cb1017eec5257e9ac6209ac172058c430e834d5d2bc21961dceeb79d111e5909"},
- {file = "numpy-1.19.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:edb01671b3caae1ca00881686003d16c2209e07b7ef8b7639f1867852b948f7c"},
- {file = "numpy-1.19.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f29454410db6ef8126c83bd3c968d143304633d45dc57b51252afbd79d700893"},
- {file = "numpy-1.19.4-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:ec149b90019852266fec2341ce1db513b843e496d5a8e8cdb5ced1923a92faab"},
- {file = "numpy-1.19.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:1aeef46a13e51931c0b1cf8ae1168b4a55ecd282e6688fdb0a948cc5a1d5afb9"},
- {file = "numpy-1.19.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:08308c38e44cc926bdfce99498b21eec1f848d24c302519e64203a8da99a97db"},
- {file = "numpy-1.19.4-cp38-cp38-win32.whl", hash = "sha256:5734bdc0342aba9dfc6f04920988140fb41234db42381cf7ccba64169f9fe7ac"},
- {file = "numpy-1.19.4-cp38-cp38-win_amd64.whl", hash = "sha256:09c12096d843b90eafd01ea1b3307e78ddd47a55855ad402b157b6c4862197ce"},
- {file = "numpy-1.19.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e452dc66e08a4ce642a961f134814258a082832c78c90351b75c41ad16f79f63"},
- {file = "numpy-1.19.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:a5d897c14513590a85774180be713f692df6fa8ecf6483e561a6d47309566f37"},
- {file = "numpy-1.19.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a09f98011236a419ee3f49cedc9ef27d7a1651df07810ae430a6b06576e0b414"},
- {file = "numpy-1.19.4-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:50e86c076611212ca62e5a59f518edafe0c0730f7d9195fec718da1a5c2bb1fc"},
- {file = "numpy-1.19.4-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f0d3929fe88ee1c155129ecd82f981b8856c5d97bcb0d5f23e9b4242e79d1de3"},
- {file = "numpy-1.19.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:c42c4b73121caf0ed6cd795512c9c09c52a7287b04d105d112068c1736d7c753"},
- {file = "numpy-1.19.4-cp39-cp39-win32.whl", hash = "sha256:8cac8790a6b1ddf88640a9267ee67b1aee7a57dfa2d2dd33999d080bc8ee3a0f"},
- {file = "numpy-1.19.4-cp39-cp39-win_amd64.whl", hash = "sha256:4377e10b874e653fe96985c05feed2225c912e328c8a26541f7fc600fb9c637b"},
- {file = "numpy-1.19.4-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:2a2740aa9733d2e5b2dfb33639d98a64c3b0f24765fed86b0fd2aec07f6a0a08"},
- {file = "numpy-1.19.4.zip", hash = "sha256:141ec3a3300ab89c7f2b0775289954d193cc8edb621ea05f99db9cb181530512"},
+ {file = "numpy-1.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ae61f02b84a0211abb56462a3b6cd1e7ec39d466d3160eb4e1da8bf6717cdbeb"},
+ {file = "numpy-1.20.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:65410c7f4398a0047eea5cca9b74009ea61178efd78d1be9847fac1d6716ec1e"},
+ {file = "numpy-1.20.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:2d7e27442599104ee08f4faed56bb87c55f8b10a5494ac2ead5c98a4b289e61f"},
+ {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:4ed8e96dc146e12c1c5cdd6fb9fd0757f2ba66048bf94c5126b7efebd12d0090"},
+ {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:ecb5b74c702358cdc21268ff4c37f7466357871f53a30e6f84c686952bef16a9"},
+ {file = "numpy-1.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b9410c0b6fed4a22554f072a86c361e417f0258838957b78bd063bde2c7f841f"},
+ {file = "numpy-1.20.1-cp37-cp37m-win32.whl", hash = "sha256:3d3087e24e354c18fb35c454026af3ed8997cfd4997765266897c68d724e4845"},
+ {file = "numpy-1.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:89f937b13b8dd17b0099c7c2e22066883c86ca1575a975f754babc8fbf8d69a9"},
+ {file = "numpy-1.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a1d7995d1023335e67fb070b2fae6f5968f5be3802b15ad6d79d81ecaa014fe0"},
+ {file = "numpy-1.20.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:60759ab15c94dd0e1ed88241fd4fa3312db4e91d2c8f5a2d4cf3863fad83d65b"},
+ {file = "numpy-1.20.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:125a0e10ddd99a874fd357bfa1b636cd58deb78ba4a30b5ddb09f645c3512e04"},
+ {file = "numpy-1.20.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:c26287dfc888cf1e65181f39ea75e11f42ffc4f4529e5bd19add57ad458996e2"},
+ {file = "numpy-1.20.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:7199109fa46277be503393be9250b983f325880766f847885607d9b13848f257"},
+ {file = "numpy-1.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:72251e43ac426ff98ea802a931922c79b8d7596480300eb9f1b1e45e0543571e"},
+ {file = "numpy-1.20.1-cp38-cp38-win32.whl", hash = "sha256:c91ec9569facd4757ade0888371eced2ecf49e7982ce5634cc2cf4e7331a4b14"},
+ {file = "numpy-1.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:13adf545732bb23a796914fe5f891a12bd74cf3d2986eed7b7eba2941eea1590"},
+ {file = "numpy-1.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:104f5e90b143dbf298361a99ac1af4cf59131218a045ebf4ee5990b83cff5fab"},
+ {file = "numpy-1.20.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:89e5336f2bec0c726ac7e7cdae181b325a9c0ee24e604704ed830d241c5e47ff"},
+ {file = "numpy-1.20.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:032be656d89bbf786d743fee11d01ef318b0781281241997558fa7950028dd29"},
+ {file = "numpy-1.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:66b467adfcf628f66ea4ac6430ded0614f5cc06ba530d09571ea404789064adc"},
+ {file = "numpy-1.20.1-cp39-cp39-win32.whl", hash = "sha256:12e4ba5c6420917571f1a5becc9338abbde71dd811ce40b37ba62dec7b39af6d"},
+ {file = "numpy-1.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:9c94cab5054bad82a70b2e77741271790304651d584e2cdfe2041488e753863b"},
+ {file = "numpy-1.20.1-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:9eb551d122fadca7774b97db8a112b77231dcccda8e91a5bc99e79890797175e"},
+ {file = "numpy-1.20.1.zip", hash = "sha256:3bc63486a870294683980d76ec1e3efc786295ae00128f9ea38e2c6e74d5a60a"},
]
nvidia-ml-py3 = [
{file = "nvidia-ml-py3-7.352.0.tar.gz", hash = "sha256:390f02919ee9d73fe63a98c73101061a6b37fa694a793abf56673320f1f51277"},
@@ -2803,6 +2775,8 @@ pyyaml = [
{file = "PyYAML-5.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:73f099454b799e05e5ab51423c7bcf361c58d3206fa7b0d555426b1f4d9a3eaf"},
{file = "PyYAML-5.3.1-cp38-cp38-win32.whl", hash = "sha256:06a0d7ba600ce0b2d2fe2e78453a470b5a6e000a985dd4a4e54e436cc36b0e97"},
{file = "PyYAML-5.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:95f71d2af0ff4227885f7a6605c37fd53d3a106fcab511b8860ecca9fcf400ee"},
+ {file = "PyYAML-5.3.1-cp39-cp39-win32.whl", hash = "sha256:ad9c67312c84def58f3c04504727ca879cb0013b2517c85a9a253f0cb6380c0a"},
+ {file = "PyYAML-5.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:6034f55dab5fea9e53f436aa68fa3ace2634918e8b5994d82f3621c04ff5ed2e"},
{file = "PyYAML-5.3.1.tar.gz", hash = "sha256:b8eac752c5e14d3eca0e6dd9199cd627518cb5ec06add0de9d32baeee6fe645d"},
]
pyzmq = [
@@ -2822,11 +2796,13 @@ pyzmq = [
{file = "pyzmq-20.0.0-cp37-cp37m-win32.whl", hash = "sha256:c95dda497a7c1b1e734b5e8353173ca5dd7b67784d8821d13413a97856588057"},
{file = "pyzmq-20.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:cc09c5cd1a4332611c8564d65e6a432dc6db3e10793d0254da9fa1e31d9ffd6d"},
{file = "pyzmq-20.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6e24907857c80dc67692e31f5bf3ad5bf483ee0142cec95b3d47e2db8c43bdda"},
+ {file = "pyzmq-20.0.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:53706f4a792cdae422121fb6a5e65119bad02373153364fc9d004cf6a90394de"},
{file = "pyzmq-20.0.0-cp38-cp38-manylinux1_i686.whl", hash = "sha256:895695be380f0f85d2e3ec5ccf68a93c92d45bd298567525ad5633071589872c"},
{file = "pyzmq-20.0.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:d92c7f41a53ece82b91703ea433c7d34143248cf0cead33aa11c5fc621c764bf"},
{file = "pyzmq-20.0.0-cp38-cp38-win32.whl", hash = "sha256:309d763d89ec1845c0e0fa14e1fb6558fd8c9ef05ed32baec27d7a8499cc7bb0"},
{file = "pyzmq-20.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:0e554fd390021edbe0330b67226325a820b0319c5b45e1b0a59bf22ccc36e793"},
{file = "pyzmq-20.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cfa54a162a7b32641665e99b2c12084555afe9fc8fe80ec8b2f71a57320d10e1"},
+ {file = "pyzmq-20.0.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:dc2f48b575dff6edefd572f1ac84cf0c3f18ad5fcf13384de32df740a010594a"},
{file = "pyzmq-20.0.0-cp39-cp39-manylinux1_i686.whl", hash = "sha256:5efe02bdcc5eafcac0aab531292294298f0ab8d28ed43be9e507d0e09173d1a4"},
{file = "pyzmq-20.0.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:0af84f34f27b5c6a0e906c648bdf46d4caebf9c8e6e16db0728f30a58141cad6"},
{file = "pyzmq-20.0.0-cp39-cp39-win32.whl", hash = "sha256:c63fafd2556d218368c51d18588f8e6f8d86d09d493032415057faf6de869b34"},
@@ -3052,6 +3028,7 @@ stevedore = [
]
subprocess32 = [
{file = "subprocess32-3.5.4-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:88e37c1aac5388df41cc8a8456bb49ebffd321a3ad4d70358e3518176de3a56b"},
+ {file = "subprocess32-3.5.4-cp27-cp27mu-manylinux2014_x86_64.whl", hash = "sha256:e45d985aef903c5b7444d34350b05da91a9e0ea015415ab45a21212786c649d0"},
{file = "subprocess32-3.5.4.tar.gz", hash = "sha256:eb2937c80497978d181efa1b839ec2d9622cf9600a039a79d0e108d1f9aec79d"},
]
terminado = [
@@ -3071,24 +3048,32 @@ toml = [
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
]
torch = [
- {file = "torch-1.7.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:6b0c9b56cb56afe3ecbac79351d21c6f7172dffc7b7daa8c365f660541baf1a5"},
- {file = "torch-1.7.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:e8cc3b2c3937b7ae036a3b447a189af049bfc006bca054fc1d8ae78766ca3105"},
- {file = "torch-1.7.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:1520c48430dea38e5845b7b3defc9054edad45f1f245808aa268ade840bb2c2a"},
- {file = "torch-1.7.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:89cb8774243750bd3fd2b3b3d09bab6e3be68b1785ad48b8411f1eb4fc7acdba"},
- {file = "torch-1.7.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:11054f26eee5c3114d217201dba5b3a35f1745d11133c123c077c5981bc95997"},
- {file = "torch-1.7.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:b8000e39600e101b2f19dbbab75de663a3b78e3979c3e1720b7136aae1c35ce2"},
+ {file = "torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:422e64e98d0e100c360993819d0307e5d56e9517b26135808ad68984d577d75a"},
+ {file = "torch-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f0aaf657145533824b15f2fd8fde8f8c67fe6c6281088ef588091f03fad90243"},
+ {file = "torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:af464a6f4314a875035e0c4c2b07517599704b214634f4ed3ad2e748c5ef291f"},
+ {file = "torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d76c255a41484c1d41a9ff570b9c9f36cb85df9428aa15a58ae16ac7cfc2ea6"},
+ {file = "torch-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d241c3f1c4d563e4ba86f84769c23e12606db167ee6f674eedff6d02901462e3"},
+ {file = "torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:de84b4166e3f7335eb868b51d3bbd909ec33828af27290b4171bce832a55be3c"},
+ {file = "torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:dd2fc6880c95e836960d86efbbc7f63d3287f2e1893c51d31f96dbfe02f0d73e"},
+ {file = "torch-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:e000b94be3aa58ad7f61e7d07cf379ea9366cf6c6874e68bd58ad0bdc537b3a7"},
+ {file = "torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2e49cac969976be63117004ee00d0a3e3dd4ea662ad77383f671b8992825de1a"},
+ {file = "torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a3793dcceb12b1e2281290cca1277c5ce86ddfd5bf044f654285a4d69057aea7"},
+ {file = "torch-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:6652a767a0572ae0feb74ad128758e507afd3b8396b6e7f147e438ba8d4c6f63"},
+ {file = "torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:38d67f4fb189a92a977b2c0a38e4f6dd413e0bf55aa6d40004696df7e40a71ff"},
]
torch-summary = [
{file = "torch-summary-1.4.3.tar.gz", hash = "sha256:2dcbc1dfd07dca9f4080bcacdaf90db3f2fc28efee348c8fba9033039b0e8c82"},
{file = "torch_summary-1.4.3-py3-none-any.whl", hash = "sha256:a0a76916bd11d054fd3863dc7c474971922badfbc13d6404f9eddd297041f094"},
]
torchvision = [
- {file = "torchvision-0.8.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:80b1c6d0a97e86454c15cf9f1afcf0751761273b7687c3d0910336ea87cca8d4"},
- {file = "torchvision-0.8.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:307daa1daa4cc1a2380dd26f81d3a9670535fff8927f1049dc76d4e47253fb8e"},
- {file = "torchvision-0.8.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b58262a2bd2d419d94d7bf8aaa3a532b9283f4995e766723cc4cc3a52d8883c8"},
- {file = "torchvision-0.8.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:95b0ce59e631e2c97e6069dff126a43232cca859b18a1b505e5b02dd1a65dd0f"},
- {file = "torchvision-0.8.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:469e0b831bfe17c46159966b5dc7ba09c87eaeecbed6f9a4d6ec4e691b0c8827"},
- {file = "torchvision-0.8.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:337820e680e5193872903369d8177d5ea681e7156d370d89d487b0e0f1e56238"},
+ {file = "torchvision-0.8.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86fae370d222f76ad57c57c3bee03f78b8db727743bfb4c1559a3d395159cea8"},
+ {file = "torchvision-0.8.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:951239b5fcb911dbf78c1385d677f5f48c7a1b12859e3d3ec287562821b17cf2"},
+ {file = "torchvision-0.8.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24db8f4c3d812a032273f68563ad5dbd724f5bfbed523d0c6dce8cede26bb153"},
+ {file = "torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b068f6bcbe91bdd34dda0a39e8a26392add45a3be82543f6dd523b76484fb56f"},
+ {file = "torchvision-0.8.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb76a66b9b0693f758a881a2bf333ed97e3c0c3f15a413c4f49d8dd8bd21307"},
+ {file = "torchvision-0.8.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd8817e9197fc60ebae37162a445db90bbf35591314a5767ad3d1490b5d65b0f"},
+ {file = "torchvision-0.8.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bd58acc3366ec02266aae56a7a752d43ef07de4a6ba420c4f907d0c9168bb8c"},
+ {file = "torchvision-0.8.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:976750a49db2e23dc5a1ed0b5c31f7af51ed2702eee410ee09ef985c3a3e48cf"},
]
tornado = [
{file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"},
@@ -3149,19 +3134,28 @@ typed-ast = [
{file = "typed_ast-1.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75"},
{file = "typed_ast-1.4.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652"},
{file = "typed_ast-1.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7"},
+ {file = "typed_ast-1.4.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:fcf135e17cc74dbfbc05894ebca928ffeb23d9790b3167a674921db19082401f"},
{file = "typed_ast-1.4.1-cp36-cp36m-win32.whl", hash = "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1"},
{file = "typed_ast-1.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa"},
{file = "typed_ast-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614"},
{file = "typed_ast-1.4.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41"},
{file = "typed_ast-1.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b"},
+ {file = "typed_ast-1.4.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:f208eb7aff048f6bea9586e61af041ddf7f9ade7caed625742af423f6bae3298"},
{file = "typed_ast-1.4.1-cp37-cp37m-win32.whl", hash = "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe"},
{file = "typed_ast-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355"},
{file = "typed_ast-1.4.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6"},
{file = "typed_ast-1.4.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907"},
{file = "typed_ast-1.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d"},
+ {file = "typed_ast-1.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:7e4c9d7658aaa1fc80018593abdf8598bf91325af6af5cce4ce7c73bc45ea53d"},
{file = "typed_ast-1.4.1-cp38-cp38-win32.whl", hash = "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c"},
{file = "typed_ast-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4"},
{file = "typed_ast-1.4.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34"},
+ {file = "typed_ast-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:92c325624e304ebf0e025d1224b77dd4e6393f18aab8d829b5b7e04afe9b7a2c"},
+ {file = "typed_ast-1.4.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:d648b8e3bf2fe648745c8ffcee3db3ff903d0817a01a12dd6a6ea7a8f4889072"},
+ {file = "typed_ast-1.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:fac11badff8313e23717f3dada86a15389d0708275bddf766cca67a84ead3e91"},
+ {file = "typed_ast-1.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:0d8110d78a5736e16e26213114a38ca35cb15b6515d535413b090bd50951556d"},
+ {file = "typed_ast-1.4.1-cp39-cp39-win32.whl", hash = "sha256:b52ccf7cfe4ce2a1064b18594381bccf4179c2ecf7f513134ec2f993dd4ab395"},
+ {file = "typed_ast-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:3742b32cf1c6ef124d57f95be609c473d7ec4c14d0090e5a5e05a15269fb4d0c"},
{file = "typed_ast-1.4.1.tar.gz", hash = "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b"},
]
typeguard = [
diff --git a/pyproject.toml b/pyproject.toml
index 4c674bc..2f774b2 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ marshmallow = "^3.6.0"
sphinx-autodoc-typehints = "^1.10.3"
sphinx_rtd_theme = "^0.4.3"
boltons = "^20.1.0"
-h5py = "^2.10.0"
+h5py = "^3.2.1"
toml = "^0.10.1"
torch = "^1.7.0"
torchvision = "^0.8.1"
diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb
deleted file mode 100644
index 2d6b43c..0000000
--- a/src/notebooks/00-testing-stuff-out.ipynb
+++ /dev/null
@@ -1,1059 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "\n",
- "%matplotlib inline\n",
- "import matplotlib.pyplot as plt\n",
- "import numpy as np\n",
- "from PIL import Image\n",
- "import torch.nn.functional as F\n",
- "import torch\n",
- "from torch import nn\n",
- "from torchsummary import summary\n",
- "from importlib.util import find_spec\n",
- "if find_spec(\"text_recognizer\") is None:\n",
- " import sys\n",
- " sys.path.append('..')"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks import CNN, TDS2d"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "tds2d = TDS2d(**{\n",
- " \"depth\" : 4,\n",
- " \"tds_groups\" : [\n",
- " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
- " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
- " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n",
- " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n",
- " ],\n",
- " \"kernel_size\" : [5, 7],\n",
- " \"dropout_rate\" : 0.1\n",
- " }, input_dim=32, output_dim=128)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "tds2d"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randn(2,1, 28, 952)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "tds2d(t).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "cnn = CNN()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "i = nn.Sequential(nn.Conv2d(1,1,1,1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "nn.Sequential(i,i)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "cnn(t).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randn(2, 1, 28, 952)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "x, l = vqvae(t)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "5 * 59 / 10"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "x.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "up = nn.Upsample([4, 59])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "up(tt).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "tt.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "class GEGLU(nn.Module):\n",
- " def __init__(self, dim_in, dim_out):\n",
- " super().__init__()\n",
- " self.proj = nn.Linear(dim_in, dim_out * 2)\n",
- "\n",
- " def forward(self, x):\n",
- " x, gate = self.proj(x).chunk(2, dim = -1)\n",
- " return x * F.gelu(gate)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "e = GEGLU(256, 2048)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "e(t).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "emb = nn.Embedding(56, 256)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "with torch.no_grad():\n",
- " e = emb(torch.Tensor([55]).long())"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from einops import repeat"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "ee = repeat(e, \"() n -> b n\", b=16)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "emb.device"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "ee"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "ee.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.randn(16, 10, 256)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "e.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks import WideResidualNetwork"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "wr = WideResidualNetwork(\n",
- " in_channels= 1,\n",
- " num_classes= 80,\n",
- " in_planes=64,\n",
- " depth=10,\n",
- " num_layers=4,\n",
- " width_factor=2,\n",
- " num_stages=[64, 128, 256, 256],\n",
- " dropout_rate= 0.1,\n",
- " activation= \"SELU\",\n",
- " use_decoder= False,\n",
- ")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from torchsummary import summary"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- " backbone = nn.Sequential(\n",
- " *list(wr.children())[:][:]\n",
- " )\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "backbone"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "a = torch.rand(1, 1, 28, 952)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "b = wr(a)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from einops import rearrange"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "b = rearrange(b, \"b c h w -> b w c h\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "c = nn.AdaptiveAvgPool2d((None, 1))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "d = c(b)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "d.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "d.squeeze(3).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "b.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from torch import nn"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "32 + 64"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "3 * 112"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "col_embed = nn.Parameter(torch.rand(1000, 256 // 2))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "W, H = 196, 4"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- " torch.cat(\n",
- " [\n",
- " col_embed[:W].unsqueeze(0).repeat(H, 1, 1),\n",
- " col_embed[:H].unsqueeze(1).repeat(1, W, 1),\n",
- " ],\n",
- " dim=-1,\n",
- " ).unsqueeze(0).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "4 * 196"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "target = torch.tensor([1,1,12,1,1,1,1,1,9,9,9,9,9,9])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "torch.nonzero(target == 9, as_tuple=False)[0].item()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "target[:9]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "np.inf"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.transformer.positional_encoding import PositionalEncoding"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plt.figure(figsize=(15, 5))\n",
- "pe = PositionalEncoding(20, 0)\n",
- "y = pe.forward(torch.zeros(1, 100, 20))\n",
- "plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())\n",
- "plt.legend([\"dim %d\"%p for p in [4,5,6,7]])\n",
- "None"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks.densenet import DenseNet,_DenseLayer,_DenseBlock"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "216 / 8"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- " backbone = nn.Sequential(\n",
- " *list(dnet.children())[:][:-4]\n",
- " )"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "backbone"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.networks import WideResidualNetwork"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "w = WideResidualNetwork(\n",
- " in_channels = 1,\n",
- " in_planes = 32,\n",
- " num_classes = 80,\n",
- " depth = 10,\n",
- " width_factor = 1,\n",
- " dropout_rate = 0.0,\n",
- " num_layers = 5,\n",
- " activation = \"relu\",\n",
- " use_decoder = False,)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "summary(w, (1, 28, 952), device=\"cpu\", depth=2)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "sz= 5"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "mask = torch.triu(torch.ones(sz, sz), 1)\n",
- "mask = mask.masked_fill(mask==1, float('-inf'))"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "\n",
- "h = torch.rand(1, 256, 10, 10)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "h.flatten(2).permute(2, 0, 1).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "h.flatten(2).permute(2, 0, 1).shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "mask\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred = torch.Tensor([1,21,2,45,31, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n",
- "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "mask = (target != 79)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "mask"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred * mask"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "target * mask"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.models.metrics import accuracy"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pad_indcies = torch.nonzero(target == 79, as_tuple=False)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "target.shape[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t2 = torch.arange(10, target.shape[0] + 1, 10)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "t2"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "for start, stop in zip(t1, t2):\n",
- " pred[start+1:stop] = 79"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "scrolled": true
- },
- "outputs": [],
- "source": [
- "pad_indcies"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred[pad_indcies:pad_indcies] = 79"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "pred.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "target.shape"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "accuracy(pred, target)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "acc = (pred == target).sum().float() / target.shape[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "acc"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": []
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.8.2"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/src/text_recognizer/tests/support/emnist/8.png b/src/text_recognizer/tests/support/emnist/8.png
deleted file mode 100644
index faa29aa..0000000
--- a/src/text_recognizer/tests/support/emnist/8.png
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist/U.png b/src/text_recognizer/tests/support/emnist/U.png
deleted file mode 100644
index 304eaec..0000000
--- a/src/text_recognizer/tests/support/emnist/U.png
+++ /dev/null
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist/e.png b/src/text_recognizer/tests/support/emnist/e.png
deleted file mode 100644
index a03ecd4..0000000
--- a/src/text_recognizer/tests/support/emnist/e.png
+++ /dev/null
Binary files differ
diff --git a/src/tasks/build_transitions.py b/tasks/build_transitions.py
index 91f8c1a..91f8c1a 100644
--- a/src/tasks/build_transitions.py
+++ b/tasks/build_transitions.py
diff --git a/src/tasks/create_emnist_lines_datasets.sh b/tasks/create_emnist_lines_datasets.sh
index 6416277..6416277 100755
--- a/src/tasks/create_emnist_lines_datasets.sh
+++ b/tasks/create_emnist_lines_datasets.sh
diff --git a/src/tasks/create_iam_paragraphs.sh b/tasks/create_iam_paragraphs.sh
index fa2bfb0..fa2bfb0 100755
--- a/src/tasks/create_iam_paragraphs.sh
+++ b/tasks/create_iam_paragraphs.sh
diff --git a/src/tasks/download_emnist.sh b/tasks/download_emnist.sh
index 18c8e29..18c8e29 100755
--- a/src/tasks/download_emnist.sh
+++ b/tasks/download_emnist.sh
diff --git a/src/tasks/download_iam.sh b/tasks/download_iam.sh
index e3cf76b..e3cf76b 100755
--- a/src/tasks/download_iam.sh
+++ b/tasks/download_iam.sh
diff --git a/src/tasks/make_wordpieces.py b/tasks/make_wordpieces.py
index 2ac0e2c..2ac0e2c 100644
--- a/src/tasks/make_wordpieces.py
+++ b/tasks/make_wordpieces.py
diff --git a/src/tasks/prepare_experiments.sh b/tasks/prepare_experiments.sh
index 95a538f..95a538f 100755
--- a/src/tasks/prepare_experiments.sh
+++ b/tasks/prepare_experiments.sh
diff --git a/src/tasks/test_functionality.sh b/tasks/test_functionality.sh
index 5ccf0cd..5ccf0cd 100755
--- a/src/tasks/test_functionality.sh
+++ b/tasks/test_functionality.sh
diff --git a/src/tasks/train.sh b/tasks/train.sh
index 60cbd23..60cbd23 100755
--- a/src/tasks/train.sh
+++ b/tasks/train.sh
diff --git a/src/text_recognizer/__init__.py b/text_recognizer/__init__.py
index 3dc1f76..3dc1f76 100644
--- a/src/text_recognizer/__init__.py
+++ b/text_recognizer/__init__.py
diff --git a/src/text_recognizer/character_predictor.py b/text_recognizer/character_predictor.py
index ad71289..ad71289 100644
--- a/src/text_recognizer/character_predictor.py
+++ b/text_recognizer/character_predictor.py
diff --git a/src/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py
index a6c1c59..a6c1c59 100644
--- a/src/text_recognizer/datasets/__init__.py
+++ b/text_recognizer/datasets/__init__.py
diff --git a/src/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py
index e794605..e794605 100644
--- a/src/text_recognizer/datasets/dataset.py
+++ b/text_recognizer/datasets/dataset.py
diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py
index 9884fdf..9884fdf 100644
--- a/src/text_recognizer/datasets/emnist_dataset.py
+++ b/text_recognizer/datasets/emnist_dataset.py
diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json
index 2a0648a..2a0648a 100644
--- a/src/text_recognizer/datasets/emnist_essentials.json
+++ b/text_recognizer/datasets/emnist_essentials.json
diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/text_recognizer/datasets/emnist_lines_dataset.py
index 1992446..1992446 100644
--- a/src/text_recognizer/datasets/emnist_lines_dataset.py
+++ b/text_recognizer/datasets/emnist_lines_dataset.py
diff --git a/src/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py
index f4a869d..a8998b9 100644
--- a/src/text_recognizer/datasets/iam_dataset.py
+++ b/text_recognizer/datasets/iam_dataset.py
@@ -13,6 +13,7 @@ from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME
RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam"
METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml"
EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb"
+RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True)
DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be.
LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates.
diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py
index 1cb84bd..1cb84bd 100644
--- a/src/text_recognizer/datasets/iam_lines_dataset.py
+++ b/text_recognizer/datasets/iam_lines_dataset.py
diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py
index 8ba5142..8ba5142 100644
--- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py
+++ b/text_recognizer/datasets/iam_paragraphs_dataset.py
diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py
index a93eb00..a93eb00 100644
--- a/src/text_recognizer/datasets/iam_preprocessor.py
+++ b/text_recognizer/datasets/iam_preprocessor.py
diff --git a/src/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py
index dd76652..dd76652 100644
--- a/src/text_recognizer/datasets/sentence_generator.py
+++ b/text_recognizer/datasets/sentence_generator.py
diff --git a/src/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py
index b6a48f5..b6a48f5 100644
--- a/src/text_recognizer/datasets/transforms.py
+++ b/text_recognizer/datasets/transforms.py
diff --git a/src/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py
index da87756..da87756 100644
--- a/src/text_recognizer/datasets/util.py
+++ b/text_recognizer/datasets/util.py
diff --git a/src/text_recognizer/line_predictor.py b/text_recognizer/line_predictor.py
index 8e348fe..8e348fe 100644
--- a/src/text_recognizer/line_predictor.py
+++ b/text_recognizer/line_predictor.py
diff --git a/src/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index 7647d7e..7647d7e 100644
--- a/src/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
diff --git a/src/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 70f4cdb..70f4cdb 100644
--- a/src/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
diff --git a/src/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py
index f9944f3..f9944f3 100644
--- a/src/text_recognizer/models/character_model.py
+++ b/text_recognizer/models/character_model.py
diff --git a/src/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py
index 1e01a83..1e01a83 100644
--- a/src/text_recognizer/models/crnn_model.py
+++ b/text_recognizer/models/crnn_model.py
diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py
index 25925f2..25925f2 100644
--- a/src/text_recognizer/models/ctc_transformer_model.py
+++ b/text_recognizer/models/ctc_transformer_model.py
diff --git a/src/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py
index 613108a..613108a 100644
--- a/src/text_recognizer/models/segmentation_model.py
+++ b/text_recognizer/models/segmentation_model.py
diff --git a/src/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py
index 3f63053..3f63053 100644
--- a/src/text_recognizer/models/transformer_model.py
+++ b/text_recognizer/models/transformer_model.py
diff --git a/src/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py
index 70f6f1f..70f6f1f 100644
--- a/src/text_recognizer/models/vqvae_model.py
+++ b/text_recognizer/models/vqvae_model.py
diff --git a/src/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py
index 1521355..1521355 100644
--- a/src/text_recognizer/networks/__init__.py
+++ b/text_recognizer/networks/__init__.py
diff --git a/src/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py
index dccccdb..dccccdb 100644
--- a/src/text_recognizer/networks/beam.py
+++ b/text_recognizer/networks/beam.py
diff --git a/src/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py
index 1807bb9..1807bb9 100644
--- a/src/text_recognizer/networks/cnn.py
+++ b/text_recognizer/networks/cnn.py
diff --git a/src/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py
index a2d7926..9150b55 100644
--- a/src/text_recognizer/networks/cnn_transformer.py
+++ b/text_recognizer/networks/cnn_transformer.py
@@ -1,7 +1,7 @@
"""A CNN-Transformer for image to text recognition."""
from typing import Dict, Optional, Tuple
-from einops import rearrange, repeat
+from einops import rearrange
import torch
from torch import nn
from torch import Tensor
diff --git a/src/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py
index 778e232..778e232 100644
--- a/src/text_recognizer/networks/crnn.py
+++ b/text_recognizer/networks/crnn.py
diff --git a/src/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py
index af9b700..af9b700 100644
--- a/src/text_recognizer/networks/ctc.py
+++ b/text_recognizer/networks/ctc.py
diff --git a/src/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py
index 7dc58d9..7dc58d9 100644
--- a/src/text_recognizer/networks/densenet.py
+++ b/text_recognizer/networks/densenet.py
diff --git a/src/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py
index 527e1a0..527e1a0 100644
--- a/src/text_recognizer/networks/lenet.py
+++ b/text_recognizer/networks/lenet.py
diff --git a/src/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py
index b489264..b489264 100644
--- a/src/text_recognizer/networks/loss/__init__.py
+++ b/text_recognizer/networks/loss/__init__.py
diff --git a/src/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py
index cf9fa0d..cf9fa0d 100644
--- a/src/text_recognizer/networks/loss/loss.py
+++ b/text_recognizer/networks/loss/loss.py
diff --git a/src/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py
index 2605731..2605731 100644
--- a/src/text_recognizer/networks/metrics.py
+++ b/text_recognizer/networks/metrics.py
diff --git a/src/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py
index 1101912..1101912 100644
--- a/src/text_recognizer/networks/mlp.py
+++ b/text_recognizer/networks/mlp.py
diff --git a/src/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py
index c33f419..c33f419 100644
--- a/src/text_recognizer/networks/residual_network.py
+++ b/text_recognizer/networks/residual_network.py
diff --git a/src/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py
index e9d216f..e9d216f 100644
--- a/src/text_recognizer/networks/stn.py
+++ b/text_recognizer/networks/stn.py
diff --git a/src/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py
index 8c19a01..8c19a01 100644
--- a/src/text_recognizer/networks/transducer/__init__.py
+++ b/text_recognizer/networks/transducer/__init__.py
diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py
index 5fb8ba9..5fb8ba9 100644
--- a/src/text_recognizer/networks/transducer/tds_conv.py
+++ b/text_recognizer/networks/transducer/tds_conv.py
diff --git a/src/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py
index cadcecc..cadcecc 100644
--- a/src/text_recognizer/networks/transducer/test.py
+++ b/text_recognizer/networks/transducer/test.py
diff --git a/src/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py
index d7e3d08..d7e3d08 100644
--- a/src/text_recognizer/networks/transducer/transducer.py
+++ b/text_recognizer/networks/transducer/transducer.py
diff --git a/src/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py
index 9febc88..9febc88 100644
--- a/src/text_recognizer/networks/transformer/__init__.py
+++ b/text_recognizer/networks/transformer/__init__.py
diff --git a/src/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py
index cce1ecc..cce1ecc 100644
--- a/src/text_recognizer/networks/transformer/attention.py
+++ b/text_recognizer/networks/transformer/attention.py
diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py
index 1ba5537..1ba5537 100644
--- a/src/text_recognizer/networks/transformer/positional_encoding.py
+++ b/text_recognizer/networks/transformer/positional_encoding.py
diff --git a/src/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py
index dd180c4..dd180c4 100644
--- a/src/text_recognizer/networks/transformer/transformer.py
+++ b/text_recognizer/networks/transformer/transformer.py
diff --git a/src/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py
index 510910f..510910f 100644
--- a/src/text_recognizer/networks/unet.py
+++ b/text_recognizer/networks/unet.py
diff --git a/src/text_recognizer/networks/util.py b/text_recognizer/networks/util.py
index 131a6b4..131a6b4 100644
--- a/src/text_recognizer/networks/util.py
+++ b/text_recognizer/networks/util.py
diff --git a/src/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py
index efb3701..efb3701 100644
--- a/src/text_recognizer/networks/vit.py
+++ b/text_recognizer/networks/vit.py
diff --git a/src/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index c673d96..c673d96 100644
--- a/src/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py
index 763953c..763953c 100644
--- a/src/text_recognizer/networks/vqvae/__init__.py
+++ b/text_recognizer/networks/vqvae/__init__.py
diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index 8847aba..8847aba 100644
--- a/src/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index d3adac5..d3adac5 100644
--- a/src/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py
index f92c7ee..f92c7ee 100644
--- a/src/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/text_recognizer/networks/vqvae/vector_quantizer.py
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 50448b4..50448b4 100644
--- a/src/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
diff --git a/src/text_recognizer/networks/wide_resnet.py b/text_recognizer/networks/wide_resnet.py
index b767778..b767778 100644
--- a/src/text_recognizer/networks/wide_resnet.py
+++ b/text_recognizer/networks/wide_resnet.py
diff --git a/src/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py
index aa39662..aa39662 100644
--- a/src/text_recognizer/paragraph_text_recognizer.py
+++ b/text_recognizer/paragraph_text_recognizer.py
diff --git a/src/text_recognizer/tests/__init__.py b/text_recognizer/tests/__init__.py
index 18ff212..18ff212 100644
--- a/src/text_recognizer/tests/__init__.py
+++ b/text_recognizer/tests/__init__.py
diff --git a/src/text_recognizer/tests/support/__init__.py b/text_recognizer/tests/support/__init__.py
index a265ede..a265ede 100644
--- a/src/text_recognizer/tests/support/__init__.py
+++ b/text_recognizer/tests/support/__init__.py
diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/text_recognizer/tests/support/create_emnist_lines_support_files.py
index 9abe143..9abe143 100644
--- a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py
+++ b/text_recognizer/tests/support/create_emnist_lines_support_files.py
diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/text_recognizer/tests/support/create_emnist_support_files.py
index f9ff030..f9ff030 100644
--- a/src/text_recognizer/tests/support/create_emnist_support_files.py
+++ b/text_recognizer/tests/support/create_emnist_support_files.py
diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/text_recognizer/tests/support/create_iam_lines_support_files.py
index 50f9e3d..50f9e3d 100644
--- a/src/text_recognizer/tests/support/create_iam_lines_support_files.py
+++ b/text_recognizer/tests/support/create_iam_lines_support_files.py
diff --git a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
index b7d0618..b7d0618 100644
--- a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
+++ b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
index 14a8cf3..14a8cf3 100644
--- a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
+++ b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png b/text_recognizer/tests/support/emnist_lines/they<eos>.png
index 7f05951..7f05951 100644
--- a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png
+++ b/text_recognizer/tests/support/emnist_lines/they<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
index 6eeb642..6eeb642 100644
--- a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
+++ b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
index 4974cf8..4974cf8 100644
--- a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
+++ b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
index a731245..a731245 100644
--- a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
+++ b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png
Binary files differ
diff --git a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
index d9753b6..d9753b6 100644
--- a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
+++ b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg
Binary files differ
diff --git a/src/text_recognizer/tests/test_character_predictor.py b/text_recognizer/tests/test_character_predictor.py
index 01bda78..01bda78 100644
--- a/src/text_recognizer/tests/test_character_predictor.py
+++ b/text_recognizer/tests/test_character_predictor.py
diff --git a/src/text_recognizer/tests/test_line_predictor.py b/text_recognizer/tests/test_line_predictor.py
index eede4d4..eede4d4 100644
--- a/src/text_recognizer/tests/test_line_predictor.py
+++ b/text_recognizer/tests/test_line_predictor.py
diff --git a/src/text_recognizer/tests/test_paragraph_text_recognizer.py b/text_recognizer/tests/test_paragraph_text_recognizer.py
index 3e280b9..3e280b9 100644
--- a/src/text_recognizer/tests/test_paragraph_text_recognizer.py
+++ b/text_recognizer/tests/test_paragraph_text_recognizer.py
diff --git a/src/text_recognizer/util.py b/text_recognizer/util.py
index b431e22..b431e22 100644
--- a/src/text_recognizer/util.py
+++ b/text_recognizer/util.py
diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
index 344e0a3..344e0a3 100644
--- a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
+++ b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
index f2dfd84..f2dfd84 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
+++ b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt
diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
index e1add8d..e1add8d 100644
--- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
+++ b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt
diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
index d9ca01d..d9ca01d 100644
--- a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
+++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt
index 0af0e57..0af0e57 100644
--- a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt
+++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt
Binary files differ
diff --git a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
index b5295c2..b5295c2 100644
--- a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
+++ b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt
Binary files differ
diff --git a/src/training/experiments/default_config_emnist.yml b/training/experiments/default_config_emnist.yml
index bf2ed0a..bf2ed0a 100644
--- a/src/training/experiments/default_config_emnist.yml
+++ b/training/experiments/default_config_emnist.yml
diff --git a/src/training/experiments/embedding_experiment.yml b/training/experiments/embedding_experiment.yml
index 1e5f941..1e5f941 100644
--- a/src/training/experiments/embedding_experiment.yml
+++ b/training/experiments/embedding_experiment.yml
diff --git a/src/training/experiments/sample_experiment.yml b/training/experiments/sample_experiment.yml
index 8f94475..8f94475 100644
--- a/src/training/experiments/sample_experiment.yml
+++ b/training/experiments/sample_experiment.yml
diff --git a/src/training/gpu_manager.py b/training/gpu_manager.py
index ce1b3dd..ce1b3dd 100644
--- a/src/training/gpu_manager.py
+++ b/training/gpu_manager.py
diff --git a/src/training/prepare_experiments.py b/training/prepare_experiments.py
index 21997af..21997af 100644
--- a/src/training/prepare_experiments.py
+++ b/training/prepare_experiments.py
diff --git a/src/training/run_experiment.py b/training/run_experiment.py
index faafea6..faafea6 100644
--- a/src/training/run_experiment.py
+++ b/training/run_experiment.py
diff --git a/src/training/run_sweep.py b/training/run_sweep.py
index a578592..a578592 100644
--- a/src/training/run_sweep.py
+++ b/training/run_sweep.py
diff --git a/src/training/sweep_emnist.yml b/training/sweep_emnist.yml
index 48d7261..48d7261 100644
--- a/src/training/sweep_emnist.yml
+++ b/training/sweep_emnist.yml
diff --git a/src/training/sweep_emnist_resnet.yml b/training/sweep_emnist_resnet.yml
index 19a3040..19a3040 100644
--- a/src/training/sweep_emnist_resnet.yml
+++ b/training/sweep_emnist_resnet.yml
diff --git a/src/training/trainer/__init__.py b/training/trainer/__init__.py
index de41bfb..de41bfb 100644
--- a/src/training/trainer/__init__.py
+++ b/training/trainer/__init__.py
diff --git a/src/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py
index 80c4177..80c4177 100644
--- a/src/training/trainer/callbacks/__init__.py
+++ b/training/trainer/callbacks/__init__.py
diff --git a/src/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py
index 500b642..500b642 100644
--- a/src/training/trainer/callbacks/base.py
+++ b/training/trainer/callbacks/base.py
diff --git a/src/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py
index a54e0a9..a54e0a9 100644
--- a/src/training/trainer/callbacks/checkpoint.py
+++ b/training/trainer/callbacks/checkpoint.py
diff --git a/src/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py
index 02b431f..02b431f 100644
--- a/src/training/trainer/callbacks/early_stopping.py
+++ b/training/trainer/callbacks/early_stopping.py
diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py
index 630c434..630c434 100644
--- a/src/training/trainer/callbacks/lr_schedulers.py
+++ b/training/trainer/callbacks/lr_schedulers.py
diff --git a/src/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py
index 6c4305a..6c4305a 100644
--- a/src/training/trainer/callbacks/progress_bar.py
+++ b/training/trainer/callbacks/progress_bar.py
diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py
index 552a4f4..552a4f4 100644
--- a/src/training/trainer/callbacks/wandb_callbacks.py
+++ b/training/trainer/callbacks/wandb_callbacks.py
diff --git a/src/training/trainer/train.py b/training/trainer/train.py
index b770c94..b770c94 100644
--- a/src/training/trainer/train.py
+++ b/training/trainer/train.py
diff --git a/src/training/trainer/util.py b/training/trainer/util.py
index 7cf1b45..7cf1b45 100644
--- a/src/training/trainer/util.py
+++ b/training/trainer/util.py
diff --git a/src/wandb/settings b/wandb/settings
index eafb083..eafb083 100644
--- a/src/wandb/settings
+++ b/wandb/settings