diff options
-rw-r--r-- | .gitattributes (renamed from src/.gitattributes) | 0 | ||||
-rw-r--r-- | notebooks/00-testing-stuff-out.ipynb | 1469 | ||||
-rw-r--r-- | notebooks/01-look-at-emnist.ipynb (renamed from src/notebooks/01-look-at-emnist.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/02a-sentence-generator.ipynb (renamed from src/notebooks/02a-sentence-generator.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/02b-emnist-lines-dataset.ipynb (renamed from src/notebooks/02b-emnist-lines-dataset.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/02c-image-patches.ipynb (renamed from src/notebooks/02c-image-patches.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/03a-line-prediction.ipynb (renamed from src/notebooks/03a-line-prediction.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/04a-look-at-iam-lines.ipynb (renamed from src/notebooks/04a-look-at-iam-lines.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/04b-look-at-iam-paragraphs-predictions.ipynb (renamed from src/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/04b-look-at-iam-paragraphs.ipynb (renamed from src/notebooks/04b-look-at-iam-paragraphs.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/05-sanity-check-multihead-attention.ipynb (renamed from src/notebooks/05-sanity-check-multihead-attention.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/05a-UNet.ipynb (renamed from src/notebooks/05a-UNet.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/05a-test-end-to-end-model.ipynb (renamed from src/notebooks/05a-test-end-to-end-model.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/06-try-transformer-model-predictions.ipynb (renamed from src/notebooks/06-try-transformer-model-predictions.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/07-look-at-lexicon.ipynb (renamed from src/notebooks/07-look-at-lexicon.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/07-try-gtn.ipynb (renamed from src/notebooks/07-try-gtn.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/Untitled.ipynb (renamed from src/notebooks/Untitled.ipynb) | 0 | ||||
-rw-r--r-- | notebooks/g1.png (renamed from src/notebooks/g1.png) | bin | 8590 -> 8590 bytes | |||
-rw-r--r-- | notebooks/g2.png (renamed from src/notebooks/g2.png) | bin | 5247 -> 5247 bytes | |||
-rw-r--r-- | notebooks/intersect.png (renamed from src/notebooks/intersect.png) | bin | 7953 -> 7953 bytes | |||
-rw-r--r-- | notebooks/intersection.pdf (renamed from src/notebooks/intersection.pdf) | bin | 10154 -> 10154 bytes | |||
-rw-r--r-- | noxfile.py | 20 | ||||
-rw-r--r-- | poetry.lock | 216 | ||||
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | src/notebooks/00-testing-stuff-out.ipynb | 1059 | ||||
-rw-r--r-- | src/text_recognizer/tests/support/emnist/8.png | bin | 498 -> 0 bytes | |||
-rw-r--r-- | src/text_recognizer/tests/support/emnist/U.png | bin | 524 -> 0 bytes | |||
-rw-r--r-- | src/text_recognizer/tests/support/emnist/e.png | bin | 563 -> 0 bytes | |||
-rw-r--r-- | tasks/build_transitions.py (renamed from src/tasks/build_transitions.py) | 0 | ||||
-rwxr-xr-x | tasks/create_emnist_lines_datasets.sh (renamed from src/tasks/create_emnist_lines_datasets.sh) | 0 | ||||
-rwxr-xr-x | tasks/create_iam_paragraphs.sh (renamed from src/tasks/create_iam_paragraphs.sh) | 0 | ||||
-rwxr-xr-x | tasks/download_emnist.sh (renamed from src/tasks/download_emnist.sh) | 0 | ||||
-rwxr-xr-x | tasks/download_iam.sh (renamed from src/tasks/download_iam.sh) | 0 | ||||
-rw-r--r-- | tasks/make_wordpieces.py (renamed from src/tasks/make_wordpieces.py) | 0 | ||||
-rwxr-xr-x | tasks/prepare_experiments.sh (renamed from src/tasks/prepare_experiments.sh) | 0 | ||||
-rwxr-xr-x | tasks/test_functionality.sh (renamed from src/tasks/test_functionality.sh) | 0 | ||||
-rwxr-xr-x | tasks/train.sh (renamed from src/tasks/train.sh) | 0 | ||||
-rw-r--r-- | text_recognizer/__init__.py (renamed from src/text_recognizer/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/character_predictor.py (renamed from src/text_recognizer/character_predictor.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/__init__.py (renamed from src/text_recognizer/datasets/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/dataset.py (renamed from src/text_recognizer/datasets/dataset.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_dataset.py (renamed from src/text_recognizer/datasets/emnist_dataset.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_essentials.json (renamed from src/text_recognizer/datasets/emnist_essentials.json) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/emnist_lines_dataset.py (renamed from src/text_recognizer/datasets/emnist_lines_dataset.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_dataset.py (renamed from src/text_recognizer/datasets/iam_dataset.py) | 1 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_lines_dataset.py (renamed from src/text_recognizer/datasets/iam_lines_dataset.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_paragraphs_dataset.py (renamed from src/text_recognizer/datasets/iam_paragraphs_dataset.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/iam_preprocessor.py (renamed from src/text_recognizer/datasets/iam_preprocessor.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/sentence_generator.py (renamed from src/text_recognizer/datasets/sentence_generator.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/transforms.py (renamed from src/text_recognizer/datasets/transforms.py) | 0 | ||||
-rw-r--r-- | text_recognizer/datasets/util.py (renamed from src/text_recognizer/datasets/util.py) | 0 | ||||
-rw-r--r-- | text_recognizer/line_predictor.py (renamed from src/text_recognizer/line_predictor.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/__init__.py (renamed from src/text_recognizer/models/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/base.py (renamed from src/text_recognizer/models/base.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/character_model.py (renamed from src/text_recognizer/models/character_model.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/crnn_model.py (renamed from src/text_recognizer/models/crnn_model.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/ctc_transformer_model.py (renamed from src/text_recognizer/models/ctc_transformer_model.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/segmentation_model.py (renamed from src/text_recognizer/models/segmentation_model.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/transformer_model.py (renamed from src/text_recognizer/models/transformer_model.py) | 0 | ||||
-rw-r--r-- | text_recognizer/models/vqvae_model.py (renamed from src/text_recognizer/models/vqvae_model.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/__init__.py (renamed from src/text_recognizer/networks/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/beam.py (renamed from src/text_recognizer/networks/beam.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/cnn.py (renamed from src/text_recognizer/networks/cnn.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/cnn_transformer.py (renamed from src/text_recognizer/networks/cnn_transformer.py) | 2 | ||||
-rw-r--r-- | text_recognizer/networks/crnn.py (renamed from src/text_recognizer/networks/crnn.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/ctc.py (renamed from src/text_recognizer/networks/ctc.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/densenet.py (renamed from src/text_recognizer/networks/densenet.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/lenet.py (renamed from src/text_recognizer/networks/lenet.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/loss/__init__.py (renamed from src/text_recognizer/networks/loss/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/loss/loss.py (renamed from src/text_recognizer/networks/loss/loss.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/metrics.py (renamed from src/text_recognizer/networks/metrics.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/mlp.py (renamed from src/text_recognizer/networks/mlp.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/residual_network.py (renamed from src/text_recognizer/networks/residual_network.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/stn.py (renamed from src/text_recognizer/networks/stn.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/__init__.py (renamed from src/text_recognizer/networks/transducer/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/tds_conv.py (renamed from src/text_recognizer/networks/transducer/tds_conv.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/test.py (renamed from src/text_recognizer/networks/transducer/test.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transducer/transducer.py (renamed from src/text_recognizer/networks/transducer/transducer.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/__init__.py (renamed from src/text_recognizer/networks/transformer/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/attention.py (renamed from src/text_recognizer/networks/transformer/attention.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/positional_encoding.py (renamed from src/text_recognizer/networks/transformer/positional_encoding.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/transformer/transformer.py (renamed from src/text_recognizer/networks/transformer/transformer.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/unet.py (renamed from src/text_recognizer/networks/unet.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/util.py (renamed from src/text_recognizer/networks/util.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/vit.py (renamed from src/text_recognizer/networks/vit.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/vq_transformer.py (renamed from src/text_recognizer/networks/vq_transformer.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/__init__.py (renamed from src/text_recognizer/networks/vqvae/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/decoder.py (renamed from src/text_recognizer/networks/vqvae/decoder.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/encoder.py (renamed from src/text_recognizer/networks/vqvae/encoder.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/vector_quantizer.py (renamed from src/text_recognizer/networks/vqvae/vector_quantizer.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/vqvae.py (renamed from src/text_recognizer/networks/vqvae/vqvae.py) | 0 | ||||
-rw-r--r-- | text_recognizer/networks/wide_resnet.py (renamed from src/text_recognizer/networks/wide_resnet.py) | 0 | ||||
-rw-r--r-- | text_recognizer/paragraph_text_recognizer.py (renamed from src/text_recognizer/paragraph_text_recognizer.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/__init__.py (renamed from src/text_recognizer/tests/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/support/__init__.py (renamed from src/text_recognizer/tests/support/__init__.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/support/create_emnist_lines_support_files.py (renamed from src/text_recognizer/tests/support/create_emnist_lines_support_files.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/support/create_emnist_support_files.py (renamed from src/text_recognizer/tests/support/create_emnist_support_files.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/support/create_iam_lines_support_files.py (renamed from src/text_recognizer/tests/support/create_iam_lines_support_files.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png (renamed from src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png) | bin | 2301 -> 2301 bytes | |||
-rw-r--r-- | text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png (renamed from src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png) | bin | 5424 -> 5424 bytes | |||
-rw-r--r-- | text_recognizer/tests/support/emnist_lines/they<eos>.png (renamed from src/text_recognizer/tests/support/emnist_lines/they<eos>.png) | bin | 1391 -> 1391 bytes | |||
-rw-r--r-- | text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png (renamed from src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png) | bin | 5170 -> 5170 bytes | |||
-rw-r--r-- | text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png (renamed from src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png) | bin | 3617 -> 3617 bytes | |||
-rw-r--r-- | text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png (renamed from src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png) | bin | 3923 -> 3923 bytes | |||
-rw-r--r-- | text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg (renamed from src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg) | bin | 14890 -> 14890 bytes | |||
-rw-r--r-- | text_recognizer/tests/test_character_predictor.py (renamed from src/text_recognizer/tests/test_character_predictor.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/test_line_predictor.py (renamed from src/text_recognizer/tests/test_line_predictor.py) | 0 | ||||
-rw-r--r-- | text_recognizer/tests/test_paragraph_text_recognizer.py (renamed from src/text_recognizer/tests/test_paragraph_text_recognizer.py) | 0 | ||||
-rw-r--r-- | text_recognizer/util.py (renamed from src/text_recognizer/util.py) | 0 | ||||
-rw-r--r-- | text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt (renamed from src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt) | 0 | ||||
-rw-r--r-- | text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt (renamed from src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt) | 0 | ||||
-rw-r--r-- | text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt (renamed from src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt) | 0 | ||||
-rw-r--r-- | text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt (renamed from src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt) | bin | 8588813 -> 8588813 bytes | |||
-rw-r--r-- | text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt (renamed from src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt) | bin | 92335101 -> 92335101 bytes | |||
-rw-r--r-- | text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt (renamed from src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt) | bin | 21687018 -> 21687018 bytes | |||
-rw-r--r-- | training/experiments/default_config_emnist.yml (renamed from src/training/experiments/default_config_emnist.yml) | 0 | ||||
-rw-r--r-- | training/experiments/embedding_experiment.yml (renamed from src/training/experiments/embedding_experiment.yml) | 0 | ||||
-rw-r--r-- | training/experiments/sample_experiment.yml (renamed from src/training/experiments/sample_experiment.yml) | 0 | ||||
-rw-r--r-- | training/gpu_manager.py (renamed from src/training/gpu_manager.py) | 0 | ||||
-rw-r--r-- | training/prepare_experiments.py (renamed from src/training/prepare_experiments.py) | 0 | ||||
-rw-r--r-- | training/run_experiment.py (renamed from src/training/run_experiment.py) | 0 | ||||
-rw-r--r-- | training/run_sweep.py (renamed from src/training/run_sweep.py) | 0 | ||||
-rw-r--r-- | training/sweep_emnist.yml (renamed from src/training/sweep_emnist.yml) | 0 | ||||
-rw-r--r-- | training/sweep_emnist_resnet.yml (renamed from src/training/sweep_emnist_resnet.yml) | 0 | ||||
-rw-r--r-- | training/trainer/__init__.py (renamed from src/training/trainer/__init__.py) | 0 | ||||
-rw-r--r-- | training/trainer/callbacks/__init__.py (renamed from src/training/trainer/callbacks/__init__.py) | 0 | ||||
-rw-r--r-- | training/trainer/callbacks/base.py (renamed from src/training/trainer/callbacks/base.py) | 0 | ||||
-rw-r--r-- | training/trainer/callbacks/checkpoint.py (renamed from src/training/trainer/callbacks/checkpoint.py) | 0 | ||||
-rw-r--r-- | training/trainer/callbacks/early_stopping.py (renamed from src/training/trainer/callbacks/early_stopping.py) | 0 | ||||
-rw-r--r-- | training/trainer/callbacks/lr_schedulers.py (renamed from src/training/trainer/callbacks/lr_schedulers.py) | 0 | ||||
-rw-r--r-- | training/trainer/callbacks/progress_bar.py (renamed from src/training/trainer/callbacks/progress_bar.py) | 0 | ||||
-rw-r--r-- | training/trainer/callbacks/wandb_callbacks.py (renamed from src/training/trainer/callbacks/wandb_callbacks.py) | 0 | ||||
-rw-r--r-- | training/trainer/train.py (renamed from src/training/trainer/train.py) | 0 | ||||
-rw-r--r-- | training/trainer/util.py (renamed from src/training/trainer/util.py) | 0 | ||||
-rw-r--r-- | wandb/settings (renamed from src/wandb/settings) | 0 |
135 files changed, 1587 insertions, 1182 deletions
diff --git a/src/.gitattributes b/.gitattributes index eebe826..eebe826 100644 --- a/src/.gitattributes +++ b/.gitattributes diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb new file mode 100644 index 0000000..becd918 --- /dev/null +++ b/notebooks/00-testing-stuff-out.ipynb @@ -0,0 +1,1469 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "import torch.nn.functional as F\n", + "import torch\n", + "from torch import nn\n", + "from torchsummary import summary\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks import CNN, TDS2d" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "tds2d = TDS2d(**{\n", + " \"depth\" : 4,\n", + " \"tds_groups\" : [\n", + " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", + " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", + " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", + " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n", + " ],\n", + " \"kernel_size\" : [5, 7],\n", + " \"dropout_rate\" : 0.1\n", + " }, input_dim=32, output_dim=128)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TDS2d(\n", + " (tds): Sequential(\n", + " (0): Conv2d(1, 16, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (4): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=16, out_features=16, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=16, out_features=16, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (5): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=16, out_features=16, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=16, out_features=16, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (6): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(4, 4, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=16, out_features=16, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=16, out_features=16, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (7): Conv2d(16, 128, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n", + " (8): ReLU(inplace=True)\n", + " (9): Dropout(p=0.1, inplace=False)\n", + " (10): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (11): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=128, out_features=128, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=128, out_features=128, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (12): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=128, out_features=128, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=128, out_features=128, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (13): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(32, 32, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=128, out_features=128, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=128, out_features=128, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (14): Conv2d(128, 256, kernel_size=[5, 7], stride=[2, 2], padding=(2, 3))\n", + " (15): ReLU(inplace=True)\n", + " (16): Dropout(p=0.1, inplace=False)\n", + " (17): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (18): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=256, out_features=256, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (19): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=256, out_features=256, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (20): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(64, 64, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=256, out_features=256, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=256, out_features=256, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (21): Conv2d(256, 512, kernel_size=[5, 7], stride=[2, 1], padding=(2, 3))\n", + " (22): ReLU(inplace=True)\n", + " (23): Dropout(p=0.1, inplace=False)\n", + " (24): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (25): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=512, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (26): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=512, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " (27): TDSBlock2d(\n", + " (conv): Sequential(\n", + " (0): Conv3d(128, 128, kernel_size=(1, 5, 7), stride=(1, 1, 1), padding=(0, 2, 3))\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (mlp): Sequential(\n", + " (0): Linear(in_features=512, out_features=512, bias=True)\n", + " (1): ReLU(inplace=True)\n", + " (2): Dropout(p=0.1, inplace=False)\n", + " (3): Linear(in_features=512, out_features=512, bias=True)\n", + " (4): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (instance_norm): ModuleList(\n", + " (0): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " (1): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)\n", + " )\n", + " )\n", + " )\n", + " (fc): Linear(in_features=1024, out_features=128, bias=True)\n", + ")" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tds2d" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "===============================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "===============================================================================================\n", + "├─Sequential: 1-1 [-1, 512, 2, 119] --\n", + "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n", + "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n", + "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n", + "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n", + "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n", + "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n", + "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n", + "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n", + "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n", + "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n", + "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n", + "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n", + "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n", + "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n", + "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n", + "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n", + "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n", + "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n", + "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n", + "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n", + "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n", + "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n", + "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n", + "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n", + "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n", + "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n", + "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n", + "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n", + "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n", + "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n", + "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n", + "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n", + "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n", + "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n", + "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n", + "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n", + "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n", + "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n", + "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n", + "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n", + "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n", + "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n", + "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n", + "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n", + "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n", + "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n", + "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n", + "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n", + "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n", + "├─Linear: 1-2 [-1, 119, 128] 131,200\n", + "===============================================================================================\n", + "Total params: 10,272,252\n", + "Trainable params: 10,272,252\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 5.00\n", + "===============================================================================================\n", + "Input size (MB): 0.10\n", + "Forward/backward pass size (MB): 73.21\n", + "Params size (MB): 39.19\n", + "Estimated Total Size (MB): 112.50\n", + "===============================================================================================\n" + ] + }, + { + "data": { + "text/plain": [ + "===============================================================================================\n", + "Layer (type:depth-idx) Output Shape Param #\n", + "===============================================================================================\n", + "├─Sequential: 1-1 [-1, 512, 2, 119] --\n", + "| └─Conv2d: 2-1 [-1, 16, 14, 476] 576\n", + "| └─ReLU: 2-2 [-1, 16, 14, 476] --\n", + "| └─Dropout: 2-3 [-1, 16, 14, 476] --\n", + "| └─InstanceNorm2d: 2-4 [-1, 16, 14, 476] 32\n", + "| └─TDSBlock2d: 2-5 [-1, 16, 14, 476] --\n", + "| | └─Sequential: 3-1 [-1, 4, 4, 14, 476] 564\n", + "| | └─Sequential: 3-2 [-1, 476, 14, 16] 544\n", + "| └─TDSBlock2d: 2-6 [-1, 16, 14, 476] --\n", + "| | └─Sequential: 3-3 [-1, 4, 4, 14, 476] 564\n", + "| | └─Sequential: 3-4 [-1, 476, 14, 16] 544\n", + "| └─TDSBlock2d: 2-7 [-1, 16, 14, 476] --\n", + "| | └─Sequential: 3-5 [-1, 4, 4, 14, 476] 564\n", + "| | └─Sequential: 3-6 [-1, 476, 14, 16] 544\n", + "| └─Conv2d: 2-8 [-1, 128, 7, 238] 71,808\n", + "| └─ReLU: 2-9 [-1, 128, 7, 238] --\n", + "| └─Dropout: 2-10 [-1, 128, 7, 238] --\n", + "| └─InstanceNorm2d: 2-11 [-1, 128, 7, 238] 256\n", + "| └─TDSBlock2d: 2-12 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-7 [-1, 32, 4, 7, 238] 35,872\n", + "| | └─Sequential: 3-8 [-1, 238, 7, 128] 33,024\n", + "| └─TDSBlock2d: 2-13 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-9 [-1, 32, 4, 7, 238] 35,872\n", + "| | └─Sequential: 3-10 [-1, 238, 7, 128] 33,024\n", + "| └─TDSBlock2d: 2-14 [-1, 128, 7, 238] --\n", + "| | └─Sequential: 3-11 [-1, 32, 4, 7, 238] 35,872\n", + "| | └─Sequential: 3-12 [-1, 238, 7, 128] 33,024\n", + "| └─Conv2d: 2-15 [-1, 256, 4, 119] 1,147,136\n", + "| └─ReLU: 2-16 [-1, 256, 4, 119] --\n", + "| └─Dropout: 2-17 [-1, 256, 4, 119] --\n", + "| └─InstanceNorm2d: 2-18 [-1, 256, 4, 119] 512\n", + "| └─TDSBlock2d: 2-19 [-1, 256, 4, 119] --\n", + "| | └─Sequential: 3-13 [-1, 64, 4, 4, 119] 143,424\n", + "| | └─Sequential: 3-14 [-1, 119, 4, 256] 131,584\n", + "| └─TDSBlock2d: 2-20 [-1, 256, 4, 119] --\n", + "| | └─Sequential: 3-15 [-1, 64, 4, 4, 119] 143,424\n", + "| | └─Sequential: 3-16 [-1, 119, 4, 256] 131,584\n", + "| └─TDSBlock2d: 2-21 [-1, 256, 4, 119] --\n", + "| | └─Sequential: 3-17 [-1, 64, 4, 4, 119] 143,424\n", + "| | └─Sequential: 3-18 [-1, 119, 4, 256] 131,584\n", + "| └─Conv2d: 2-22 [-1, 512, 2, 119] 4,588,032\n", + "| └─ReLU: 2-23 [-1, 512, 2, 119] --\n", + "| └─Dropout: 2-24 [-1, 512, 2, 119] --\n", + "| └─InstanceNorm2d: 2-25 [-1, 512, 2, 119] 1,024\n", + "| └─TDSBlock2d: 2-26 [-1, 512, 2, 119] --\n", + "| | └─Sequential: 3-19 [-1, 128, 4, 2, 119] 573,568\n", + "| | └─Sequential: 3-20 [-1, 119, 2, 512] 525,312\n", + "| └─TDSBlock2d: 2-27 [-1, 512, 2, 119] --\n", + "| | └─Sequential: 3-21 [-1, 128, 4, 2, 119] 573,568\n", + "| | └─Sequential: 3-22 [-1, 119, 2, 512] 525,312\n", + "| └─TDSBlock2d: 2-28 [-1, 512, 2, 119] --\n", + "| | └─Sequential: 3-23 [-1, 128, 4, 2, 119] 573,568\n", + "| | └─Sequential: 3-24 [-1, 119, 2, 512] 525,312\n", + "├─Linear: 1-2 [-1, 119, 128] 131,200\n", + "===============================================================================================\n", + "Total params: 10,272,252\n", + "Trainable params: 10,272,252\n", + "Non-trainable params: 0\n", + "Total mult-adds (G): 5.00\n", + "===============================================================================================\n", + "Input size (MB): 0.10\n", + "Forward/backward pass size (MB): 73.21\n", + "Params size (MB): 39.19\n", + "Estimated Total Size (MB): 112.50\n", + "===============================================================================================" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randn(2,1, 28, 952)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 119, 128])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tds2d(t).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "cnn = CNN().cuda()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "i = nn.Sequential(nn.Conv2d(1,1,1,1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nn.Sequential(i,i)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "cnn(t).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randn(2, 1, 28, 952)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x, l = vqvae(t)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "5 * 59 / 10" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "x.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "up = nn.Upsample([4, 59])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "up(tt).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "tt.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "class GEGLU(nn.Module):\n", + " def __init__(self, dim_in, dim_out):\n", + " super().__init__()\n", + " self.proj = nn.Linear(dim_in, dim_out * 2)\n", + "\n", + " def forward(self, x):\n", + " x, gate = self.proj(x).chunk(2, dim = -1)\n", + " return x * F.gelu(gate)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "e = GEGLU(256, 2048)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "e(t).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "emb = nn.Embedding(56, 256)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with torch.no_grad():\n", + " e = emb(torch.Tensor([55]).long())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import repeat" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ee = repeat(e, \"() n -> b n\", b=16)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "emb.device" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ee" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ee.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.randn(16, 10, 256)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "e.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks import WideResidualNetwork" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "wr = WideResidualNetwork(\n", + " in_channels= 1,\n", + " num_classes= 80,\n", + " in_planes=64,\n", + " depth=10,\n", + " num_layers=4,\n", + " width_factor=2,\n", + " num_stages=[64, 128, 256, 256],\n", + " dropout_rate= 0.1,\n", + " activation= \"SELU\",\n", + " use_decoder= False,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torchsummary import summary" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " backbone = nn.Sequential(\n", + " *list(wr.children())[:][:]\n", + " )\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "backbone" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "a = torch.rand(1, 1, 28, 952)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "b = wr(a)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import rearrange" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "b = rearrange(b, \"b c h w -> b w c h\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "c = nn.AdaptiveAvgPool2d((None, 1))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "d = c(b)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "d.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "d.squeeze(3).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "b.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from torch import nn" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "32 + 64" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "3 * 112" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "col_embed = nn.Parameter(torch.rand(1000, 256 // 2))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "W, H = 196, 4" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " torch.cat(\n", + " [\n", + " col_embed[:W].unsqueeze(0).repeat(H, 1, 1),\n", + " col_embed[:H].unsqueeze(1).repeat(1, W, 1),\n", + " ],\n", + " dim=-1,\n", + " ).unsqueeze(0).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "4 * 196" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target = torch.tensor([1,1,12,1,1,1,1,1,9,9,9,9,9,9])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "torch.nonzero(target == 9, as_tuple=False)[0].item()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target[:9]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "np.inf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.transformer.positional_encoding import PositionalEncoding" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(figsize=(15, 5))\n", + "pe = PositionalEncoding(20, 0)\n", + "y = pe.forward(torch.zeros(1, 100, 20))\n", + "plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())\n", + "plt.legend([\"dim %d\"%p for p in [4,5,6,7]])\n", + "None" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.densenet import DenseNet,_DenseLayer,_DenseBlock" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "216 / 8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + " backbone = nn.Sequential(\n", + " *list(dnet.children())[:][:-4]\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "backbone" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks import WideResidualNetwork" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "w = WideResidualNetwork(\n", + " in_channels = 1,\n", + " in_planes = 32,\n", + " num_classes = 80,\n", + " depth = 10,\n", + " width_factor = 1,\n", + " dropout_rate = 0.0,\n", + " num_layers = 5,\n", + " activation = \"relu\",\n", + " use_decoder = False,)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summary(w, (1, 28, 952), device=\"cpu\", depth=2)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "sz= 5" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask = torch.triu(torch.ones(sz, sz), 1)\n", + "mask = mask.masked_fill(mask==1, float('-inf'))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "h = torch.rand(1, 256, 10, 10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "h.flatten(2).permute(2, 0, 1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "h.flatten(2).permute(2, 0, 1).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred = torch.Tensor([1,21,2,45,31, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n", + "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask = (target != 79)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred * mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target * mask" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.models.metrics import accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pad_indcies = torch.nonzero(target == 79, as_tuple=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t2 = torch.arange(10, target.shape[0] + 1, 10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t2" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for start, stop in zip(t1, t2):\n", + " pred[start+1:stop] = 79" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "pad_indcies" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred[pad_indcies:pad_indcies] = 79" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "pred.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "accuracy(pred, target)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "acc = (pred == target).sum().float() / target.shape[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "acc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/src/notebooks/01-look-at-emnist.ipynb b/notebooks/01-look-at-emnist.ipynb index b70ce12..b70ce12 100644 --- a/src/notebooks/01-look-at-emnist.ipynb +++ b/notebooks/01-look-at-emnist.ipynb diff --git a/src/notebooks/02a-sentence-generator.ipynb b/notebooks/02a-sentence-generator.ipynb index 99aa56a..99aa56a 100644 --- a/src/notebooks/02a-sentence-generator.ipynb +++ b/notebooks/02a-sentence-generator.ipynb diff --git a/src/notebooks/02b-emnist-lines-dataset.ipynb b/notebooks/02b-emnist-lines-dataset.ipynb index f82342b..f82342b 100644 --- a/src/notebooks/02b-emnist-lines-dataset.ipynb +++ b/notebooks/02b-emnist-lines-dataset.ipynb diff --git a/src/notebooks/02c-image-patches.ipynb b/notebooks/02c-image-patches.ipynb index fedea91..fedea91 100644 --- a/src/notebooks/02c-image-patches.ipynb +++ b/notebooks/02c-image-patches.ipynb diff --git a/src/notebooks/03a-line-prediction.ipynb b/notebooks/03a-line-prediction.ipynb index 13f4ff1..13f4ff1 100644 --- a/src/notebooks/03a-line-prediction.ipynb +++ b/notebooks/03a-line-prediction.ipynb diff --git a/src/notebooks/04a-look-at-iam-lines.ipynb b/notebooks/04a-look-at-iam-lines.ipynb index de59a85..de59a85 100644 --- a/src/notebooks/04a-look-at-iam-lines.ipynb +++ b/notebooks/04a-look-at-iam-lines.ipynb diff --git a/src/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb index 5662eb1..5662eb1 100644 --- a/src/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb +++ b/notebooks/04b-look-at-iam-paragraphs-predictions.ipynb diff --git a/src/notebooks/04b-look-at-iam-paragraphs.ipynb b/notebooks/04b-look-at-iam-paragraphs.ipynb index dc0aef6..dc0aef6 100644 --- a/src/notebooks/04b-look-at-iam-paragraphs.ipynb +++ b/notebooks/04b-look-at-iam-paragraphs.ipynb diff --git a/src/notebooks/05-sanity-check-multihead-attention.ipynb b/notebooks/05-sanity-check-multihead-attention.ipynb index 54f0432..54f0432 100644 --- a/src/notebooks/05-sanity-check-multihead-attention.ipynb +++ b/notebooks/05-sanity-check-multihead-attention.ipynb diff --git a/src/notebooks/05a-UNet.ipynb b/notebooks/05a-UNet.ipynb index 77d895d..77d895d 100644 --- a/src/notebooks/05a-UNet.ipynb +++ b/notebooks/05a-UNet.ipynb diff --git a/src/notebooks/05a-test-end-to-end-model.ipynb b/notebooks/05a-test-end-to-end-model.ipynb index 7723b12..7723b12 100644 --- a/src/notebooks/05a-test-end-to-end-model.ipynb +++ b/notebooks/05a-test-end-to-end-model.ipynb diff --git a/src/notebooks/06-try-transformer-model-predictions.ipynb b/notebooks/06-try-transformer-model-predictions.ipynb index d39e111..d39e111 100644 --- a/src/notebooks/06-try-transformer-model-predictions.ipynb +++ b/notebooks/06-try-transformer-model-predictions.ipynb diff --git a/src/notebooks/07-look-at-lexicon.ipynb b/notebooks/07-look-at-lexicon.ipynb index b7a5a0e..b7a5a0e 100644 --- a/src/notebooks/07-look-at-lexicon.ipynb +++ b/notebooks/07-look-at-lexicon.ipynb diff --git a/src/notebooks/07-try-gtn.ipynb b/notebooks/07-try-gtn.ipynb index 4ef444b..4ef444b 100644 --- a/src/notebooks/07-try-gtn.ipynb +++ b/notebooks/07-try-gtn.ipynb diff --git a/src/notebooks/Untitled.ipynb b/notebooks/Untitled.ipynb index 841a37d..841a37d 100644 --- a/src/notebooks/Untitled.ipynb +++ b/notebooks/Untitled.ipynb diff --git a/src/notebooks/g1.png b/notebooks/g1.png Binary files differindex 09dd49e..09dd49e 100644 --- a/src/notebooks/g1.png +++ b/notebooks/g1.png diff --git a/src/notebooks/g2.png b/notebooks/g2.png Binary files differindex a3cf21e..a3cf21e 100644 --- a/src/notebooks/g2.png +++ b/notebooks/g2.png diff --git a/src/notebooks/intersect.png b/notebooks/intersect.png Binary files differindex 63b7f2f..63b7f2f 100644 --- a/src/notebooks/intersect.png +++ b/notebooks/intersect.png diff --git a/src/notebooks/intersection.pdf b/notebooks/intersection.pdf Binary files differindex c425a9f..c425a9f 100644 --- a/src/notebooks/intersection.pdf +++ b/notebooks/intersection.pdf @@ -40,7 +40,7 @@ def install_with_constraints(session: Session, *args: str, **kwargs: Any) -> Non session.install(f"--constraint={requirements.name}", *args, **kwargs) -@nox.session(python="3.8") +@nox.session(python="3.9") def black(session: Session) -> None: """Run black code formatter.""" args = session.posargs or locations @@ -48,7 +48,7 @@ def black(session: Session) -> None: session.run("black", *args) -@nox.session(python=["3.8"]) +@nox.session(python=["3.9"]) def lint(session: Session) -> None: """Lint using flake8.""" args = session.posargs or locations @@ -66,7 +66,7 @@ def lint(session: Session) -> None: session.run("flake8", *args) -@nox.session(python="3.8") +@nox.session(python="3.9") def safety(session: Session) -> None: """Scan dependencies for insecure packages.""" with tempfile.NamedTemporaryFile() as requirements: @@ -83,7 +83,7 @@ def safety(session: Session) -> None: session.run("safety", "check", f"--file={requirements.name}", "--full-report") -@nox.session(python=["3.8"]) +@nox.session(python=["3.9"]) def mypy(session: Session) -> None: """Type-check using mypy.""" args = session.posargs or locations @@ -91,7 +91,7 @@ def mypy(session: Session) -> None: session.run("mypy", *args) -@nox.session(python="3.8") +@nox.session(python="3.9") def pytype(session: Session) -> None: """Type-check using pytype.""" args = session.posargs or ["--disable=import-error", *locations] @@ -99,7 +99,7 @@ def pytype(session: Session) -> None: session.run("pytype", *args) -@nox.session(python=["3.8"]) +@nox.session(python=["3.9"]) def tests(session: Session) -> None: """Run the test suite.""" args = session.posargs or ["--cov", "-m", "not e2e"] @@ -110,7 +110,7 @@ def tests(session: Session) -> None: session.run("pytest", *args) -@nox.session(python=["3.8"]) +@nox.session(python=["3.9"]) def typeguard(session: Session) -> None: """Runtime type checking using Typeguard.""" args = session.posargs or ["-m", "not e2e"] @@ -119,7 +119,7 @@ def typeguard(session: Session) -> None: session.run("pytest", f"--typeguard-packages={package}", *args) -@nox.session(python=["3.8"]) +@nox.session(python=["3.9"]) def xdoctest(session: Session) -> None: """Run examples with xdoctest.""" args = session.posargs or ["all"] @@ -128,7 +128,7 @@ def xdoctest(session: Session) -> None: session.run("python", "-m", "xdoctest", package, *args) -@nox.session(python="3.8") +@nox.session(python="3.9") def coverage(session: Session) -> None: """Upload coverage data.""" install_with_constraints(session, "coverage[toml]", "codecov") @@ -136,7 +136,7 @@ def coverage(session: Session) -> None: session.run("codecov", *session.posargs) -@nox.session(python="3.8") +@nox.session(python="3.9") def docs(session: Session) -> None: """Build the documentation.""" session.run("poetry", "install", "--no-dev", external=True) diff --git a/poetry.lock b/poetry.lock index 72da168..78f086e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -244,14 +244,6 @@ optional = false python-versions = ">=3.6,<4.0" [[package]] -name = "dataclasses" -version = "0.6" -description = "A backport of the dataclasses module for Python 3.6" -category = "main" -optional = false -python-versions = "*" - -[[package]] name = "decorator" version = "4.4.2" description = "Decorators for Humans" @@ -432,14 +424,6 @@ python-versions = "*" flake8 = "*" [[package]] -name = "future" -version = "0.18.2" -description = "Clean single-source support for Python 3 and 2" -category = "main" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" - -[[package]] name = "gitdb" version = "4.0.5" description = "Git Object Database" @@ -501,15 +485,17 @@ python-versions = ">=3.5" [[package]] name = "h5py" -version = "2.10.0" +version = "3.2.1" description = "Read and write HDF5 files from Python" category = "main" optional = false -python-versions = "*" +python-versions = ">=3.7" [package.dependencies] -numpy = ">=1.7" -six = "*" +numpy = [ + {version = ">=1.17.5", markers = "python_version == \"3.8\""}, + {version = ">=1.19.3", markers = "python_version >= \"3.9\""}, +] [[package]] name = "idna" @@ -995,11 +981,11 @@ test = ["nose", "coverage", "requests", "nose-warnings-filters", "nbval", "nose- [[package]] name = "numpy" -version = "1.19.4" +version = "1.20.1" description = "NumPy is the fundamental package for array computing with Python." category = "main" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" [[package]] name = "nvidia-ml-py3" @@ -1029,6 +1015,9 @@ category = "main" optional = false python-versions = ">=3.6" +[package.dependencies] +numpy = ">=1.19.3" + [[package]] name = "packaging" version = "20.4" @@ -1782,15 +1771,13 @@ python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" [[package]] name = "torch" -version = "1.7.0" +version = "1.7.1" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" category = "main" optional = false -python-versions = ">=3.6.1" +python-versions = ">=3.6.2" [package.dependencies] -dataclasses = "*" -future = "*" numpy = "*" typing-extensions = "*" @@ -1804,7 +1791,7 @@ python-versions = ">=3.5" [[package]] name = "torchvision" -version = "0.8.1" +version = "0.8.2" description = "image and video datasets and models for torch deep learning" category = "main" optional = false @@ -1813,7 +1800,7 @@ python-versions = "*" [package.dependencies] numpy = "*" pillow = ">=4.1.1" -torch = "1.7.0" +torch = "1.7.1" [package.extras] scipy = ["scipy"] @@ -2006,7 +1993,7 @@ tests = ["pytest", "pytest-cov", "codecov", "scikit-build", "cmake", "ninja", "p [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "1f194d7de179e9676ef1f8e51b83ff15c001627803008ef8225e8e14ab3acab0" +content-hash = "c87742a388e1277e84313b4c0ff75681d754c8328db2c488c0aba2a4dafc6a64" [metadata.files] alabaster = [ @@ -2038,6 +2025,8 @@ argon2-cffi = [ {file = "argon2_cffi-20.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:6678bb047373f52bcff02db8afab0d2a77d83bde61cfecea7c5c62e2335cb203"}, {file = "argon2_cffi-20.1.0-cp38-cp38-win32.whl", hash = "sha256:77e909cc756ef81d6abb60524d259d959bab384832f0c651ed7dcb6e5ccdbb78"}, {file = "argon2_cffi-20.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:9dfd5197852530294ecb5795c97a823839258dfd5eb9420233c7cfedec2058f2"}, + {file = "argon2_cffi-20.1.0-cp39-cp39-win32.whl", hash = "sha256:e2db6e85c057c16d0bd3b4d2b04f270a7467c147381e8fd73cbbe5bc719832be"}, + {file = "argon2_cffi-20.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:8a84934bd818e14a17943de8099d41160da4a336bcc699bb4c394bbb9b94bd32"}, ] async-generator = [ {file = "async_generator-1.10-py3-none-any.whl", hash = "sha256:01c7bf666359b4967d2cda0000cc2e4af16a0ae098cbffcb8472fb9e8ad6585b"}, @@ -2182,10 +2171,6 @@ darglint = [ {file = "darglint-1.5.6-py3-none-any.whl", hash = "sha256:6fcef385e646c4da9ea6fc547e28c77a33ae0cba4806b8585ae18a490a797e82"}, {file = "darglint-1.5.6.tar.gz", hash = "sha256:98acb4064bae73ec02146cb123dd3c930bd5272e562ad4d19c59857443632dd1"}, ] -dataclasses = [ - {file = "dataclasses-0.6-py3-none-any.whl", hash = "sha256:454a69d788c7fda44efd71e259be79577822f5e3f53f029a22d08004e951dc9f"}, - {file = "dataclasses-0.6.tar.gz", hash = "sha256:6988bd2b895eef432d562370bb707d540f32f7360ab13da45340101bc2307d84"}, -] decorator = [ {file = "decorator-4.4.2-py2.py3-none-any.whl", hash = "sha256:41fa54c2a0cc4ba648be4fd43cff00aedf5b9465c9bf18d64325bc225f08f760"}, {file = "decorator-4.4.2.tar.gz", hash = "sha256:e3a62f0520172440ca0dcc823749319382e377f37f140a0b99ef45fecb84bfe7"}, @@ -2248,9 +2233,6 @@ flake8-polyfill = [ {file = "flake8-polyfill-1.0.2.tar.gz", hash = "sha256:e44b087597f6da52ec6393a709e7108b2905317d0c0b744cdca6208e670d8eda"}, {file = "flake8_polyfill-1.0.2-py2.py3-none-any.whl", hash = "sha256:12be6a34ee3ab795b19ca73505e7b55826d5f6ad7230d31b18e106400169b9e9"}, ] -future = [ - {file = "future-0.18.2.tar.gz", hash = "sha256:b1bead90b70cf6ec3f0710ae53a525360fa360d306a86583adc6bf83a4db537d"}, -] gitdb = [ {file = "gitdb-4.0.5-py3-none-any.whl", hash = "sha256:91f36bfb1ab7949b3b40e23736db18231bf7593edada2ba5c3a174a7b23657ac"}, {file = "gitdb-4.0.5.tar.gz", hash = "sha256:c9e1f2d0db7ddb9a704c2a0217be31214e91a4fe1dea1efad19ae42ba0c285c9"}, @@ -2270,35 +2252,16 @@ gtn = [ {file = "gtn-0.0.0.tar.gz", hash = "sha256:72fece9ca51df161c1274e570d6f5f933e76f4cac9d8d6dd543a3fe0383f7268"}, ] h5py = [ - {file = "h5py-2.10.0-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:ecf4d0b56ee394a0984de15bceeb97cbe1fe485f1ac205121293fc44dcf3f31f"}, - {file = "h5py-2.10.0-cp27-cp27m-manylinux1_i686.whl", hash = "sha256:86868dc07b9cc8cb7627372a2e6636cdc7a53b7e2854ad020c9e9d8a4d3fd0f5"}, - {file = "h5py-2.10.0-cp27-cp27m-manylinux1_x86_64.whl", hash = "sha256:aac4b57097ac29089f179bbc2a6e14102dd210618e94d77ee4831c65f82f17c0"}, - {file = "h5py-2.10.0-cp27-cp27m-win32.whl", hash = "sha256:7be5754a159236e95bd196419485343e2b5875e806fe68919e087b6351f40a70"}, - {file = "h5py-2.10.0-cp27-cp27m-win_amd64.whl", hash = "sha256:13c87efa24768a5e24e360a40e0bc4c49bcb7ce1bb13a3a7f9902cec302ccd36"}, - {file = "h5py-2.10.0-cp27-cp27mu-manylinux1_i686.whl", hash = "sha256:79b23f47c6524d61f899254f5cd5e486e19868f1823298bc0c29d345c2447172"}, - {file = "h5py-2.10.0-cp27-cp27mu-manylinux1_x86_64.whl", hash = "sha256:cbf28ae4b5af0f05aa6e7551cee304f1d317dbed1eb7ac1d827cee2f1ef97a99"}, - {file = "h5py-2.10.0-cp34-cp34m-manylinux1_i686.whl", hash = "sha256:c0d4b04bbf96c47b6d360cd06939e72def512b20a18a8547fa4af810258355d5"}, - {file = "h5py-2.10.0-cp34-cp34m-manylinux1_x86_64.whl", hash = "sha256:549ad124df27c056b2e255ea1c44d30fb7a17d17676d03096ad5cd85edb32dc1"}, - {file = "h5py-2.10.0-cp35-cp35m-macosx_10_6_intel.whl", hash = "sha256:a5f82cd4938ff8761d9760af3274acf55afc3c91c649c50ab18fcff5510a14a5"}, - {file = "h5py-2.10.0-cp35-cp35m-manylinux1_i686.whl", hash = "sha256:3dad1730b6470fad853ef56d755d06bb916ee68a3d8272b3bab0c1ddf83bb99e"}, - {file = "h5py-2.10.0-cp35-cp35m-manylinux1_x86_64.whl", hash = "sha256:063947eaed5f271679ed4ffa36bb96f57bc14f44dd4336a827d9a02702e6ce6b"}, - {file = "h5py-2.10.0-cp35-cp35m-win32.whl", hash = "sha256:c54a2c0dd4957776ace7f95879d81582298c5daf89e77fb8bee7378f132951de"}, - {file = "h5py-2.10.0-cp35-cp35m-win_amd64.whl", hash = "sha256:6998be619c695910cb0effe5eb15d3a511d3d1a5d217d4bd0bebad1151ec2262"}, - {file = "h5py-2.10.0-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:ff7d241f866b718e4584fa95f520cb19405220c501bd3a53ee11871ba5166ea2"}, - {file = "h5py-2.10.0-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:54817b696e87eb9e403e42643305f142cd8b940fe9b3b490bbf98c3b8a894cf4"}, - {file = "h5py-2.10.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:d3c59549f90a891691991c17f8e58c8544060fdf3ccdea267100fa5f561ff62f"}, - {file = "h5py-2.10.0-cp36-cp36m-win32.whl", hash = "sha256:d7ae7a0576b06cb8e8a1c265a8bc4b73d05fdee6429bffc9a26a6eb531e79d72"}, - {file = "h5py-2.10.0-cp36-cp36m-win_amd64.whl", hash = "sha256:bffbc48331b4a801d2f4b7dac8a72609f0b10e6e516e5c480a3e3241e091c878"}, - {file = "h5py-2.10.0-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:51ae56894c6c93159086ffa2c94b5b3388c0400548ab26555c143e7cfa05b8e5"}, - {file = "h5py-2.10.0-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16ead3c57141101e3296ebeed79c9c143c32bdd0e82a61a2fc67e8e6d493e9d1"}, - {file = "h5py-2.10.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:f0e25bb91e7a02efccb50aba6591d3fe2c725479e34769802fcdd4076abfa917"}, - {file = "h5py-2.10.0-cp37-cp37m-win32.whl", hash = "sha256:f23951a53d18398ef1344c186fb04b26163ca6ce449ebd23404b153fd111ded9"}, - {file = "h5py-2.10.0-cp37-cp37m-win_amd64.whl", hash = "sha256:8bb1d2de101f39743f91512a9750fb6c351c032e5cd3204b4487383e34da7f75"}, - {file = "h5py-2.10.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:64f74da4a1dd0d2042e7d04cf8294e04ddad686f8eba9bb79e517ae582f6668d"}, - {file = "h5py-2.10.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:d35f7a3a6cefec82bfdad2785e78359a0e6a5fbb3f605dd5623ce88082ccd681"}, - {file = "h5py-2.10.0-cp38-cp38-win32.whl", hash = "sha256:6ef7ab1089e3ef53ca099038f3c0a94d03e3560e6aff0e9d6c64c55fb13fc681"}, - {file = "h5py-2.10.0-cp38-cp38-win_amd64.whl", hash = "sha256:769e141512b54dee14ec76ed354fcacfc7d97fea5a7646b709f7400cf1838630"}, - {file = "h5py-2.10.0.tar.gz", hash = "sha256:84412798925dc870ffd7107f045d7659e60f5d46d1c70c700375248bf6bf512d"}, + {file = "h5py-3.2.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6766104ed13ff40b3b7bfd49f13fced5274103ee9af53667e7a97c5236b14741"}, + {file = "h5py-3.2.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:4160cb0d35a83c6fb9f1cad65e826dfaeb044e001549ea78003573fb6bee4042"}, + {file = "h5py-3.2.1-cp37-cp37m-win_amd64.whl", hash = "sha256:fdabe99139a9c5e1a416b7ed38c89505f8501b376d54496e1bb737cb33df61cf"}, + {file = "h5py-3.2.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d8467fa56356ad2efad2b5986326e71d4d74505de6f6c7bb46dbba09b37459ac"}, + {file = "h5py-3.2.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:a6632ac11167bbad1a8fc5c82508b97ab8c12bdfe4b659254b6f7f63d3c76744"}, + {file = "h5py-3.2.1-cp38-cp38-win_amd64.whl", hash = "sha256:90ee8a00aca5c4e0bbd821c1f6118cb9a814c15dcfdb03572c615a4431166480"}, + {file = "h5py-3.2.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25294f2690c4813475f566663a21ef1c1b11ef892b26d46454bf0a59e507d5aa"}, + {file = "h5py-3.2.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:d791b710d3e54c4d2c32cb881b183db5674ceb03bf6a0c1f3fb3cf50d8997e0a"}, + {file = "h5py-3.2.1-cp39-cp39-win_amd64.whl", hash = "sha256:7c5b5f18c96fb63399280a724734fd91e1781c6b60e385e439ad8e654a294ba4"}, + {file = "h5py-3.2.1.tar.gz", hash = "sha256:89474be911bfcdb34cbf0d98b8ec48b578c27a89fdb1ae4ee7513f1ef8d9249e"}, ] idna = [ {file = "idna-2.10-py2.py3-none-any.whl", hash = "sha256:b97d804b1e9b523befed77c48dacec60e6dcb0b5391d57af6a65a312a90648c0"}, @@ -2426,20 +2389,39 @@ markupsafe = [ {file = "MarkupSafe-1.1.1-cp35-cp35m-win32.whl", hash = "sha256:6dd73240d2af64df90aa7c4e7481e23825ea70af4b4922f8ede5b9e35f78a3b1"}, {file = "MarkupSafe-1.1.1-cp35-cp35m-win_amd64.whl", hash = "sha256:9add70b36c5666a2ed02b43b335fe19002ee5235efd4b8a89bfcf9005bebac0d"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-macosx_10_6_intel.whl", hash = "sha256:24982cc2533820871eba85ba648cd53d8623687ff11cbb805be4ff7b4c971aff"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:d53bc011414228441014aa71dbec320c66468c1030aae3a6e29778a3382d96e5"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:00bc623926325b26bb9605ae9eae8a215691f33cae5df11ca5424f06f2d1f473"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:717ba8fe3ae9cc0006d7c451f0bb265ee07739daf76355d06366154ee68d221e"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:3b8a6499709d29c2e2399569d96719a1b21dcd94410a586a18526b143ec8470f"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:84dee80c15f1b560d55bcfe6d47b27d070b4681c699c572af2e3c7cc90a3b8e0"}, + {file = "MarkupSafe-1.1.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:b1dba4527182c95a0db8b6060cc98ac49b9e2f5e64320e2b56e47cb2831978c7"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-win32.whl", hash = "sha256:535f6fc4d397c1563d08b88e485c3496cf5784e927af890fb3c3aac7f933ec66"}, {file = "MarkupSafe-1.1.1-cp36-cp36m-win_amd64.whl", hash = "sha256:b1282f8c00509d99fef04d8ba936b156d419be841854fe901d8ae224c59f0be5"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-macosx_10_6_intel.whl", hash = "sha256:8defac2f2ccd6805ebf65f5eeb132adcf2ab57aa11fdf4c0dd5169a004710e7d"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:bf5aa3cbcfdf57fa2ee9cd1822c862ef23037f5c832ad09cfea57fa846dec193"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:46c99d2de99945ec5cb54f23c8cd5689f6d7177305ebff350a58ce5f8de1669e"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:ba59edeaa2fc6114428f1637ffff42da1e311e29382d81b339c1817d37ec93c6"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:6fffc775d90dcc9aed1b89219549b329a9250d918fd0b8fa8d93d154918422e1"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:a6a744282b7718a2a62d2ed9d993cad6f5f585605ad352c11de459f4108df0a1"}, + {file = "MarkupSafe-1.1.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:195d7d2c4fbb0ee8139a6cf67194f3973a6b3042d742ebe0a9ed36d8b6f0c07f"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win32.whl", hash = "sha256:b00c1de48212e4cc9603895652c5c410df699856a2853135b3967591e4beebc2"}, {file = "MarkupSafe-1.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9bf40443012702a1d2070043cb6291650a0841ece432556f784f004937f0f32c"}, {file = "MarkupSafe-1.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6788b695d50a51edb699cb55e35487e430fa21f1ed838122d722e0ff0ac5ba15"}, {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:cdb132fc825c38e1aeec2c8aa9338310d29d337bebbd7baa06889d09a60a1fa2"}, {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:13d3144e1e340870b25e7b10b98d779608c02016d5184cfb9927a9f10c689f42"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:acf08ac40292838b3cbbb06cfe9b2cb9ec78fce8baca31ddb87aaac2e2dc3bc2"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:d9be0ba6c527163cbed5e0857c451fcd092ce83947944d6c14bc95441203f032"}, + {file = "MarkupSafe-1.1.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:caabedc8323f1e93231b52fc32bdcde6db817623d33e100708d9a68e1f53b26b"}, {file = "MarkupSafe-1.1.1-cp38-cp38-win32.whl", hash = "sha256:596510de112c685489095da617b5bcbbac7dd6384aeebeda4df6025d0256a81b"}, {file = "MarkupSafe-1.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8313f01ba26fbbe36c7be1966a7b7424942f670f38e666995b88d012765b9be"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d73a845f227b0bfe8a7455ee623525ee656a9e2e749e4742706d80a6065d5e2c"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:98bae9582248d6cf62321dcb52aaf5d9adf0bad3b40582925ef7c7f0ed85fceb"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:2beec1e0de6924ea551859edb9e7679da6e4870d32cb766240ce17e0a0ba2014"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:7fed13866cf14bba33e7176717346713881f56d9d2bcebab207f7a036f41b850"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:6f1e273a344928347c1290119b493a1f0303c52f5a5eae5f16d74f48c15d4a85"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:feb7b34d6325451ef96bc0e36e1a6c0c1c64bc1fbec4b854f4529e51887b1621"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-win32.whl", hash = "sha256:22c178a091fc6630d0d045bdb5992d2dfe14e3259760e713c490da5323866c39"}, + {file = "MarkupSafe-1.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:b7d644ddb4dbd407d31ffb699f1d140bc35478da613b441c582aeb7c43838dd8"}, {file = "MarkupSafe-1.1.1.tar.gz", hash = "sha256:29872e92839765e546828bb7754a68c418d927cd064fd4708fab9fe9c8bb116b"}, ] marshmallow = [ @@ -2529,40 +2511,30 @@ notebook = [ {file = "notebook-6.1.5.tar.gz", hash = "sha256:3db37ae834c5f3b6378381229d0e5dfcbfb558d08c8ce646b1ad355147f5e91d"}, ] numpy = [ - {file = "numpy-1.19.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:e9b30d4bd69498fc0c3fe9db5f62fffbb06b8eb9321f92cc970f2969be5e3949"}, - {file = "numpy-1.19.4-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:fedbd128668ead37f33917820b704784aff695e0019309ad446a6d0b065b57e4"}, - {file = "numpy-1.19.4-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:8ece138c3a16db8c1ad38f52eb32be6086cc72f403150a79336eb2045723a1ad"}, - {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_i686.whl", hash = "sha256:64324f64f90a9e4ef732be0928be853eee378fd6a01be21a0a8469c4f2682c83"}, - {file = "numpy-1.19.4-cp36-cp36m-manylinux2010_x86_64.whl", hash = "sha256:ad6f2ff5b1989a4899bf89800a671d71b1612e5ff40866d1f4d8bcf48d4e5764"}, - {file = "numpy-1.19.4-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:d6c7bb82883680e168b55b49c70af29b84b84abb161cbac2800e8fcb6f2109b6"}, - {file = "numpy-1.19.4-cp36-cp36m-win32.whl", hash = "sha256:13d166f77d6dc02c0a73c1101dd87fdf01339febec1030bd810dcd53fff3b0f1"}, - {file = "numpy-1.19.4-cp36-cp36m-win_amd64.whl", hash = "sha256:448ebb1b3bf64c0267d6b09a7cba26b5ae61b6d2dbabff7c91b660c7eccf2bdb"}, - {file = "numpy-1.19.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:27d3f3b9e3406579a8af3a9f262f5339005dd25e0ecf3cf1559ff8a49ed5cbf2"}, - {file = "numpy-1.19.4-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:16c1b388cc31a9baa06d91a19366fb99ddbe1c7b205293ed072211ee5bac1ed2"}, - {file = "numpy-1.19.4-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:e5b6ed0f0b42317050c88022349d994fe72bfe35f5908617512cd8c8ef9da2a9"}, - {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:18bed2bcb39e3f758296584337966e68d2d5ba6aab7e038688ad53c8f889f757"}, - {file = "numpy-1.19.4-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:fe45becb4c2f72a0907c1d0246ea6449fe7a9e2293bb0e11c4e9a32bb0930a15"}, - {file = "numpy-1.19.4-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:6d7593a705d662be5bfe24111af14763016765f43cb6923ed86223f965f52387"}, - {file = "numpy-1.19.4-cp37-cp37m-win32.whl", hash = "sha256:6ae6c680f3ebf1cf7ad1d7748868b39d9f900836df774c453c11c5440bc15b36"}, - {file = "numpy-1.19.4-cp37-cp37m-win_amd64.whl", hash = "sha256:9eeb7d1d04b117ac0d38719915ae169aa6b61fca227b0b7d198d43728f0c879c"}, - {file = "numpy-1.19.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:cb1017eec5257e9ac6209ac172058c430e834d5d2bc21961dceeb79d111e5909"}, - {file = "numpy-1.19.4-cp38-cp38-manylinux1_i686.whl", hash = "sha256:edb01671b3caae1ca00881686003d16c2209e07b7ef8b7639f1867852b948f7c"}, - {file = "numpy-1.19.4-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:f29454410db6ef8126c83bd3c968d143304633d45dc57b51252afbd79d700893"}, - {file = "numpy-1.19.4-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:ec149b90019852266fec2341ce1db513b843e496d5a8e8cdb5ced1923a92faab"}, - {file = "numpy-1.19.4-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:1aeef46a13e51931c0b1cf8ae1168b4a55ecd282e6688fdb0a948cc5a1d5afb9"}, - {file = "numpy-1.19.4-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:08308c38e44cc926bdfce99498b21eec1f848d24c302519e64203a8da99a97db"}, - {file = "numpy-1.19.4-cp38-cp38-win32.whl", hash = "sha256:5734bdc0342aba9dfc6f04920988140fb41234db42381cf7ccba64169f9fe7ac"}, - {file = "numpy-1.19.4-cp38-cp38-win_amd64.whl", hash = "sha256:09c12096d843b90eafd01ea1b3307e78ddd47a55855ad402b157b6c4862197ce"}, - {file = "numpy-1.19.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e452dc66e08a4ce642a961f134814258a082832c78c90351b75c41ad16f79f63"}, - {file = "numpy-1.19.4-cp39-cp39-manylinux1_i686.whl", hash = "sha256:a5d897c14513590a85774180be713f692df6fa8ecf6483e561a6d47309566f37"}, - {file = "numpy-1.19.4-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a09f98011236a419ee3f49cedc9ef27d7a1651df07810ae430a6b06576e0b414"}, - {file = "numpy-1.19.4-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:50e86c076611212ca62e5a59f518edafe0c0730f7d9195fec718da1a5c2bb1fc"}, - {file = "numpy-1.19.4-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:f0d3929fe88ee1c155129ecd82f981b8856c5d97bcb0d5f23e9b4242e79d1de3"}, - {file = "numpy-1.19.4-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:c42c4b73121caf0ed6cd795512c9c09c52a7287b04d105d112068c1736d7c753"}, - {file = "numpy-1.19.4-cp39-cp39-win32.whl", hash = "sha256:8cac8790a6b1ddf88640a9267ee67b1aee7a57dfa2d2dd33999d080bc8ee3a0f"}, - {file = "numpy-1.19.4-cp39-cp39-win_amd64.whl", hash = "sha256:4377e10b874e653fe96985c05feed2225c912e328c8a26541f7fc600fb9c637b"}, - {file = "numpy-1.19.4-pp36-pypy36_pp73-manylinux2010_x86_64.whl", hash = "sha256:2a2740aa9733d2e5b2dfb33639d98a64c3b0f24765fed86b0fd2aec07f6a0a08"}, - {file = "numpy-1.19.4.zip", hash = "sha256:141ec3a3300ab89c7f2b0775289954d193cc8edb621ea05f99db9cb181530512"}, + {file = "numpy-1.20.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:ae61f02b84a0211abb56462a3b6cd1e7ec39d466d3160eb4e1da8bf6717cdbeb"}, + {file = "numpy-1.20.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:65410c7f4398a0047eea5cca9b74009ea61178efd78d1be9847fac1d6716ec1e"}, + {file = "numpy-1.20.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:2d7e27442599104ee08f4faed56bb87c55f8b10a5494ac2ead5c98a4b289e61f"}, + {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_i686.whl", hash = "sha256:4ed8e96dc146e12c1c5cdd6fb9fd0757f2ba66048bf94c5126b7efebd12d0090"}, + {file = "numpy-1.20.1-cp37-cp37m-manylinux2010_x86_64.whl", hash = "sha256:ecb5b74c702358cdc21268ff4c37f7466357871f53a30e6f84c686952bef16a9"}, + {file = "numpy-1.20.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:b9410c0b6fed4a22554f072a86c361e417f0258838957b78bd063bde2c7f841f"}, + {file = "numpy-1.20.1-cp37-cp37m-win32.whl", hash = "sha256:3d3087e24e354c18fb35c454026af3ed8997cfd4997765266897c68d724e4845"}, + {file = "numpy-1.20.1-cp37-cp37m-win_amd64.whl", hash = "sha256:89f937b13b8dd17b0099c7c2e22066883c86ca1575a975f754babc8fbf8d69a9"}, + {file = "numpy-1.20.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:a1d7995d1023335e67fb070b2fae6f5968f5be3802b15ad6d79d81ecaa014fe0"}, + {file = "numpy-1.20.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:60759ab15c94dd0e1ed88241fd4fa3312db4e91d2c8f5a2d4cf3863fad83d65b"}, + {file = "numpy-1.20.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:125a0e10ddd99a874fd357bfa1b636cd58deb78ba4a30b5ddb09f645c3512e04"}, + {file = "numpy-1.20.1-cp38-cp38-manylinux2010_i686.whl", hash = "sha256:c26287dfc888cf1e65181f39ea75e11f42ffc4f4529e5bd19add57ad458996e2"}, + {file = "numpy-1.20.1-cp38-cp38-manylinux2010_x86_64.whl", hash = "sha256:7199109fa46277be503393be9250b983f325880766f847885607d9b13848f257"}, + {file = "numpy-1.20.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:72251e43ac426ff98ea802a931922c79b8d7596480300eb9f1b1e45e0543571e"}, + {file = "numpy-1.20.1-cp38-cp38-win32.whl", hash = "sha256:c91ec9569facd4757ade0888371eced2ecf49e7982ce5634cc2cf4e7331a4b14"}, + {file = "numpy-1.20.1-cp38-cp38-win_amd64.whl", hash = "sha256:13adf545732bb23a796914fe5f891a12bd74cf3d2986eed7b7eba2941eea1590"}, + {file = "numpy-1.20.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:104f5e90b143dbf298361a99ac1af4cf59131218a045ebf4ee5990b83cff5fab"}, + {file = "numpy-1.20.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:89e5336f2bec0c726ac7e7cdae181b325a9c0ee24e604704ed830d241c5e47ff"}, + {file = "numpy-1.20.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:032be656d89bbf786d743fee11d01ef318b0781281241997558fa7950028dd29"}, + {file = "numpy-1.20.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:66b467adfcf628f66ea4ac6430ded0614f5cc06ba530d09571ea404789064adc"}, + {file = "numpy-1.20.1-cp39-cp39-win32.whl", hash = "sha256:12e4ba5c6420917571f1a5becc9338abbde71dd811ce40b37ba62dec7b39af6d"}, + {file = "numpy-1.20.1-cp39-cp39-win_amd64.whl", hash = "sha256:9c94cab5054bad82a70b2e77741271790304651d584e2cdfe2041488e753863b"}, + {file = "numpy-1.20.1-pp37-pypy37_pp73-manylinux2010_x86_64.whl", hash = "sha256:9eb551d122fadca7774b97db8a112b77231dcccda8e91a5bc99e79890797175e"}, + {file = "numpy-1.20.1.zip", hash = "sha256:3bc63486a870294683980d76ec1e3efc786295ae00128f9ea38e2c6e74d5a60a"}, ] nvidia-ml-py3 = [ {file = "nvidia-ml-py3-7.352.0.tar.gz", hash = "sha256:390f02919ee9d73fe63a98c73101061a6b37fa694a793abf56673320f1f51277"}, @@ -2803,6 +2775,8 @@ pyyaml = [ {file = "PyYAML-5.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:73f099454b799e05e5ab51423c7bcf361c58d3206fa7b0d555426b1f4d9a3eaf"}, {file = "PyYAML-5.3.1-cp38-cp38-win32.whl", hash = "sha256:06a0d7ba600ce0b2d2fe2e78453a470b5a6e000a985dd4a4e54e436cc36b0e97"}, {file = "PyYAML-5.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:95f71d2af0ff4227885f7a6605c37fd53d3a106fcab511b8860ecca9fcf400ee"}, + {file = "PyYAML-5.3.1-cp39-cp39-win32.whl", hash = "sha256:ad9c67312c84def58f3c04504727ca879cb0013b2517c85a9a253f0cb6380c0a"}, + {file = "PyYAML-5.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:6034f55dab5fea9e53f436aa68fa3ace2634918e8b5994d82f3621c04ff5ed2e"}, {file = "PyYAML-5.3.1.tar.gz", hash = "sha256:b8eac752c5e14d3eca0e6dd9199cd627518cb5ec06add0de9d32baeee6fe645d"}, ] pyzmq = [ @@ -2822,11 +2796,13 @@ pyzmq = [ {file = "pyzmq-20.0.0-cp37-cp37m-win32.whl", hash = "sha256:c95dda497a7c1b1e734b5e8353173ca5dd7b67784d8821d13413a97856588057"}, {file = "pyzmq-20.0.0-cp37-cp37m-win_amd64.whl", hash = "sha256:cc09c5cd1a4332611c8564d65e6a432dc6db3e10793d0254da9fa1e31d9ffd6d"}, {file = "pyzmq-20.0.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6e24907857c80dc67692e31f5bf3ad5bf483ee0142cec95b3d47e2db8c43bdda"}, + {file = "pyzmq-20.0.0-cp38-cp38-macosx_11_0_x86_64.whl", hash = "sha256:53706f4a792cdae422121fb6a5e65119bad02373153364fc9d004cf6a90394de"}, {file = "pyzmq-20.0.0-cp38-cp38-manylinux1_i686.whl", hash = "sha256:895695be380f0f85d2e3ec5ccf68a93c92d45bd298567525ad5633071589872c"}, {file = "pyzmq-20.0.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:d92c7f41a53ece82b91703ea433c7d34143248cf0cead33aa11c5fc621c764bf"}, {file = "pyzmq-20.0.0-cp38-cp38-win32.whl", hash = "sha256:309d763d89ec1845c0e0fa14e1fb6558fd8c9ef05ed32baec27d7a8499cc7bb0"}, {file = "pyzmq-20.0.0-cp38-cp38-win_amd64.whl", hash = "sha256:0e554fd390021edbe0330b67226325a820b0319c5b45e1b0a59bf22ccc36e793"}, {file = "pyzmq-20.0.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:cfa54a162a7b32641665e99b2c12084555afe9fc8fe80ec8b2f71a57320d10e1"}, + {file = "pyzmq-20.0.0-cp39-cp39-macosx_11_0_x86_64.whl", hash = "sha256:dc2f48b575dff6edefd572f1ac84cf0c3f18ad5fcf13384de32df740a010594a"}, {file = "pyzmq-20.0.0-cp39-cp39-manylinux1_i686.whl", hash = "sha256:5efe02bdcc5eafcac0aab531292294298f0ab8d28ed43be9e507d0e09173d1a4"}, {file = "pyzmq-20.0.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:0af84f34f27b5c6a0e906c648bdf46d4caebf9c8e6e16db0728f30a58141cad6"}, {file = "pyzmq-20.0.0-cp39-cp39-win32.whl", hash = "sha256:c63fafd2556d218368c51d18588f8e6f8d86d09d493032415057faf6de869b34"}, @@ -3052,6 +3028,7 @@ stevedore = [ ] subprocess32 = [ {file = "subprocess32-3.5.4-cp27-cp27m-macosx_10_6_intel.whl", hash = "sha256:88e37c1aac5388df41cc8a8456bb49ebffd321a3ad4d70358e3518176de3a56b"}, + {file = "subprocess32-3.5.4-cp27-cp27mu-manylinux2014_x86_64.whl", hash = "sha256:e45d985aef903c5b7444d34350b05da91a9e0ea015415ab45a21212786c649d0"}, {file = "subprocess32-3.5.4.tar.gz", hash = "sha256:eb2937c80497978d181efa1b839ec2d9622cf9600a039a79d0e108d1f9aec79d"}, ] terminado = [ @@ -3071,24 +3048,32 @@ toml = [ {file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"}, ] torch = [ - {file = "torch-1.7.0-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:6b0c9b56cb56afe3ecbac79351d21c6f7172dffc7b7daa8c365f660541baf1a5"}, - {file = "torch-1.7.0-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:e8cc3b2c3937b7ae036a3b447a189af049bfc006bca054fc1d8ae78766ca3105"}, - {file = "torch-1.7.0-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:1520c48430dea38e5845b7b3defc9054edad45f1f245808aa268ade840bb2c2a"}, - {file = "torch-1.7.0-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:89cb8774243750bd3fd2b3b3d09bab6e3be68b1785ad48b8411f1eb4fc7acdba"}, - {file = "torch-1.7.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:11054f26eee5c3114d217201dba5b3a35f1745d11133c123c077c5981bc95997"}, - {file = "torch-1.7.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:b8000e39600e101b2f19dbbab75de663a3b78e3979c3e1720b7136aae1c35ce2"}, + {file = "torch-1.7.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:422e64e98d0e100c360993819d0307e5d56e9517b26135808ad68984d577d75a"}, + {file = "torch-1.7.1-cp36-cp36m-win_amd64.whl", hash = "sha256:f0aaf657145533824b15f2fd8fde8f8c67fe6c6281088ef588091f03fad90243"}, + {file = "torch-1.7.1-cp36-none-macosx_10_9_x86_64.whl", hash = "sha256:af464a6f4314a875035e0c4c2b07517599704b214634f4ed3ad2e748c5ef291f"}, + {file = "torch-1.7.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:5d76c255a41484c1d41a9ff570b9c9f36cb85df9428aa15a58ae16ac7cfc2ea6"}, + {file = "torch-1.7.1-cp37-cp37m-win_amd64.whl", hash = "sha256:d241c3f1c4d563e4ba86f84769c23e12606db167ee6f674eedff6d02901462e3"}, + {file = "torch-1.7.1-cp37-none-macosx_10_9_x86_64.whl", hash = "sha256:de84b4166e3f7335eb868b51d3bbd909ec33828af27290b4171bce832a55be3c"}, + {file = "torch-1.7.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:dd2fc6880c95e836960d86efbbc7f63d3287f2e1893c51d31f96dbfe02f0d73e"}, + {file = "torch-1.7.1-cp38-cp38-win_amd64.whl", hash = "sha256:e000b94be3aa58ad7f61e7d07cf379ea9366cf6c6874e68bd58ad0bdc537b3a7"}, + {file = "torch-1.7.1-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:2e49cac969976be63117004ee00d0a3e3dd4ea662ad77383f671b8992825de1a"}, + {file = "torch-1.7.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a3793dcceb12b1e2281290cca1277c5ce86ddfd5bf044f654285a4d69057aea7"}, + {file = "torch-1.7.1-cp39-cp39-win_amd64.whl", hash = "sha256:6652a767a0572ae0feb74ad128758e507afd3b8396b6e7f147e438ba8d4c6f63"}, + {file = "torch-1.7.1-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:38d67f4fb189a92a977b2c0a38e4f6dd413e0bf55aa6d40004696df7e40a71ff"}, ] torch-summary = [ {file = "torch-summary-1.4.3.tar.gz", hash = "sha256:2dcbc1dfd07dca9f4080bcacdaf90db3f2fc28efee348c8fba9033039b0e8c82"}, {file = "torch_summary-1.4.3-py3-none-any.whl", hash = "sha256:a0a76916bd11d054fd3863dc7c474971922badfbc13d6404f9eddd297041f094"}, ] torchvision = [ - {file = "torchvision-0.8.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:80b1c6d0a97e86454c15cf9f1afcf0751761273b7687c3d0910336ea87cca8d4"}, - {file = "torchvision-0.8.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:307daa1daa4cc1a2380dd26f81d3a9670535fff8927f1049dc76d4e47253fb8e"}, - {file = "torchvision-0.8.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b58262a2bd2d419d94d7bf8aaa3a532b9283f4995e766723cc4cc3a52d8883c8"}, - {file = "torchvision-0.8.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:95b0ce59e631e2c97e6069dff126a43232cca859b18a1b505e5b02dd1a65dd0f"}, - {file = "torchvision-0.8.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:469e0b831bfe17c46159966b5dc7ba09c87eaeecbed6f9a4d6ec4e691b0c8827"}, - {file = "torchvision-0.8.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:337820e680e5193872903369d8177d5ea681e7156d370d89d487b0e0f1e56238"}, + {file = "torchvision-0.8.2-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:86fae370d222f76ad57c57c3bee03f78b8db727743bfb4c1559a3d395159cea8"}, + {file = "torchvision-0.8.2-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:951239b5fcb911dbf78c1385d677f5f48c7a1b12859e3d3ec287562821b17cf2"}, + {file = "torchvision-0.8.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:24db8f4c3d812a032273f68563ad5dbd724f5bfbed523d0c6dce8cede26bb153"}, + {file = "torchvision-0.8.2-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:b068f6bcbe91bdd34dda0a39e8a26392add45a3be82543f6dd523b76484fb56f"}, + {file = "torchvision-0.8.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:afb76a66b9b0693f758a881a2bf333ed97e3c0c3f15a413c4f49d8dd8bd21307"}, + {file = "torchvision-0.8.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd8817e9197fc60ebae37162a445db90bbf35591314a5767ad3d1490b5d65b0f"}, + {file = "torchvision-0.8.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1bd58acc3366ec02266aae56a7a752d43ef07de4a6ba420c4f907d0c9168bb8c"}, + {file = "torchvision-0.8.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:976750a49db2e23dc5a1ed0b5c31f7af51ed2702eee410ee09ef985c3a3e48cf"}, ] tornado = [ {file = "tornado-6.1-cp35-cp35m-macosx_10_9_x86_64.whl", hash = "sha256:d371e811d6b156d82aa5f9a4e08b58debf97c302a35714f6f45e35139c332e32"}, @@ -3149,19 +3134,28 @@ typed-ast = [ {file = "typed_ast-1.4.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:269151951236b0f9a6f04015a9004084a5ab0d5f19b57de779f908621e7d8b75"}, {file = "typed_ast-1.4.1-cp36-cp36m-manylinux1_i686.whl", hash = "sha256:24995c843eb0ad11a4527b026b4dde3da70e1f2d8806c99b7b4a7cf491612652"}, {file = "typed_ast-1.4.1-cp36-cp36m-manylinux1_x86_64.whl", hash = "sha256:fe460b922ec15dd205595c9b5b99e2f056fd98ae8f9f56b888e7a17dc2b757e7"}, + {file = "typed_ast-1.4.1-cp36-cp36m-manylinux2014_aarch64.whl", hash = "sha256:fcf135e17cc74dbfbc05894ebca928ffeb23d9790b3167a674921db19082401f"}, {file = "typed_ast-1.4.1-cp36-cp36m-win32.whl", hash = "sha256:4e3e5da80ccbebfff202a67bf900d081906c358ccc3d5e3c8aea42fdfdfd51c1"}, {file = "typed_ast-1.4.1-cp36-cp36m-win_amd64.whl", hash = "sha256:249862707802d40f7f29f6e1aad8d84b5aa9e44552d2cc17384b209f091276aa"}, {file = "typed_ast-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8ce678dbaf790dbdb3eba24056d5364fb45944f33553dd5869b7580cdbb83614"}, {file = "typed_ast-1.4.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:c9e348e02e4d2b4a8b2eedb48210430658df6951fa484e59de33ff773fbd4b41"}, {file = "typed_ast-1.4.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:bcd3b13b56ea479b3650b82cabd6b5343a625b0ced5429e4ccad28a8973f301b"}, + {file = "typed_ast-1.4.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:f208eb7aff048f6bea9586e61af041ddf7f9ade7caed625742af423f6bae3298"}, {file = "typed_ast-1.4.1-cp37-cp37m-win32.whl", hash = "sha256:d5d33e9e7af3b34a40dc05f498939f0ebf187f07c385fd58d591c533ad8562fe"}, {file = "typed_ast-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:0666aa36131496aed8f7be0410ff974562ab7eeac11ef351def9ea6fa28f6355"}, {file = "typed_ast-1.4.1-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:d205b1b46085271b4e15f670058ce182bd1199e56b317bf2ec004b6a44f911f6"}, {file = "typed_ast-1.4.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:6daac9731f172c2a22ade6ed0c00197ee7cc1221aa84cfdf9c31defeb059a907"}, {file = "typed_ast-1.4.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:498b0f36cc7054c1fead3d7fc59d2150f4d5c6c56ba7fb150c013fbc683a8d2d"}, + {file = "typed_ast-1.4.1-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:7e4c9d7658aaa1fc80018593abdf8598bf91325af6af5cce4ce7c73bc45ea53d"}, {file = "typed_ast-1.4.1-cp38-cp38-win32.whl", hash = "sha256:715ff2f2df46121071622063fc7543d9b1fd19ebfc4f5c8895af64a77a8c852c"}, {file = "typed_ast-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:fc0fea399acb12edbf8a628ba8d2312f583bdbdb3335635db062fa98cf71fca4"}, {file = "typed_ast-1.4.1-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:d43943ef777f9a1c42bf4e552ba23ac77a6351de620aa9acf64ad54933ad4d34"}, + {file = "typed_ast-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:92c325624e304ebf0e025d1224b77dd4e6393f18aab8d829b5b7e04afe9b7a2c"}, + {file = "typed_ast-1.4.1-cp39-cp39-manylinux1_i686.whl", hash = "sha256:d648b8e3bf2fe648745c8ffcee3db3ff903d0817a01a12dd6a6ea7a8f4889072"}, + {file = "typed_ast-1.4.1-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:fac11badff8313e23717f3dada86a15389d0708275bddf766cca67a84ead3e91"}, + {file = "typed_ast-1.4.1-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:0d8110d78a5736e16e26213114a38ca35cb15b6515d535413b090bd50951556d"}, + {file = "typed_ast-1.4.1-cp39-cp39-win32.whl", hash = "sha256:b52ccf7cfe4ce2a1064b18594381bccf4179c2ecf7f513134ec2f993dd4ab395"}, + {file = "typed_ast-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:3742b32cf1c6ef124d57f95be609c473d7ec4c14d0090e5a5e05a15269fb4d0c"}, {file = "typed_ast-1.4.1.tar.gz", hash = "sha256:8c8aaad94455178e3187ab22c8b01a3837f8ee50e09cf31f1ba129eb293ec30b"}, ] typeguard = [ diff --git a/pyproject.toml b/pyproject.toml index 4c674bc..2f774b2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ marshmallow = "^3.6.0" sphinx-autodoc-typehints = "^1.10.3" sphinx_rtd_theme = "^0.4.3" boltons = "^20.1.0" -h5py = "^2.10.0" +h5py = "^3.2.1" toml = "^0.10.1" torch = "^1.7.0" torchvision = "^0.8.1" diff --git a/src/notebooks/00-testing-stuff-out.ipynb b/src/notebooks/00-testing-stuff-out.ipynb deleted file mode 100644 index 2d6b43c..0000000 --- a/src/notebooks/00-testing-stuff-out.ipynb +++ /dev/null @@ -1,1059 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from PIL import Image\n", - "import torch.nn.functional as F\n", - "import torch\n", - "from torch import nn\n", - "from torchsummary import summary\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import CNN, TDS2d" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tds2d = TDS2d(**{\n", - " \"depth\" : 4,\n", - " \"tds_groups\" : [\n", - " { \"channels\" : 4, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", - " { \"channels\" : 32, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", - " { \"channels\" : 64, \"num_blocks\" : 3, \"stride\" : [2, 2] },\n", - " { \"channels\" : 128, \"num_blocks\" : 3, \"stride\" : [2, 1] },\n", - " ],\n", - " \"kernel_size\" : [5, 7],\n", - " \"dropout_rate\" : 0.1\n", - " }, input_dim=32, output_dim=128)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tds2d" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(tds2d, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randn(2,1, 28, 952)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tds2d(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cnn = CNN()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "i = nn.Sequential(nn.Conv2d(1,1,1,1))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "nn.Sequential(i,i)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cnn(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.vqvae import Encoder, Decoder, VQVAE" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vqvae = VQVAE(1, [32, 128, 128, 256], [4, 4, 4, 4], [2, 2, [1, 2], [1, 2]], 2, 32, 256, [[6, 119], [7, 238]])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randn(2, 1, 28, 952)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x, l = vqvae(t)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "5 * 59 / 10" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "x.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(vqvae, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "up = nn.Upsample([4, 59])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "up(tt).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "tt.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "class GEGLU(nn.Module):\n", - " def __init__(self, dim_in, dim_out):\n", - " super().__init__()\n", - " self.proj = nn.Linear(dim_in, dim_out * 2)\n", - "\n", - " def forward(self, x):\n", - " x, gate = self.proj(x).chunk(2, dim = -1)\n", - " return x * F.gelu(gate)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "e = GEGLU(256, 2048)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "e(t).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "emb = nn.Embedding(56, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with torch.no_grad():\n", - " e = emb(torch.Tensor([55]).long())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from einops import repeat" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ee = repeat(e, \"() n -> b n\", b=16)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "emb.device" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ee" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "ee.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.randn(16, 10, 256)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t = torch.cat((ee.unsqueeze(1), t, ee.unsqueeze(1)), dim=1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "e.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.residual_network import IdentityBlock, ResidualBlock, BasicBlock, BottleNeckBlock, ResidualLayer, ResidualNetwork, ResidualNetworkEncoder" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import WideResidualNetwork" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "wr = WideResidualNetwork(\n", - " in_channels= 1,\n", - " num_classes= 80,\n", - " in_planes=64,\n", - " depth=10,\n", - " num_layers=4,\n", - " width_factor=2,\n", - " num_stages=[64, 128, 256, 256],\n", - " dropout_rate= 0.1,\n", - " activation= \"SELU\",\n", - " use_decoder= False,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torchsummary import summary" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "backbone = ResidualNetworkEncoder(1, [64, 65, 66, 67, 68], [2, 2, 2, 2, 2])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(backbone, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - " backbone = nn.Sequential(\n", - " *list(wr.children())[:][:]\n", - " )\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "backbone" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(wr, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "a = torch.rand(1, 1, 28, 952)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b = wr(a)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from einops import rearrange" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b = rearrange(b, \"b c h w -> b w c h\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "c = nn.AdaptiveAvgPool2d((None, 1))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d = c(b)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "d.squeeze(3).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "b.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from torch import nn" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "32 + 64" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "3 * 112" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "col_embed = nn.Parameter(torch.rand(1000, 256 // 2))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "W, H = 196, 4" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "col_embed[:W].unsqueeze(0).repeat(H, 1, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "col_embed[:H].unsqueeze(1).repeat(1, W, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - " torch.cat(\n", - " [\n", - " col_embed[:W].unsqueeze(0).repeat(H, 1, 1),\n", - " col_embed[:H].unsqueeze(1).repeat(1, W, 1),\n", - " ],\n", - " dim=-1,\n", - " ).unsqueeze(0).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "4 * 196" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target = torch.tensor([1,1,12,1,1,1,1,1,9,9,9,9,9,9])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "torch.nonzero(target == 9, as_tuple=False)[0].item()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target[:9]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "np.inf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.transformer.positional_encoding import PositionalEncoding" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "plt.figure(figsize=(15, 5))\n", - "pe = PositionalEncoding(20, 0)\n", - "y = pe.forward(torch.zeros(1, 100, 20))\n", - "plt.plot(np.arange(100), y[0, :, 4:8].data.numpy())\n", - "plt.legend([\"dim %d\"%p for p in [4,5,6,7]])\n", - "None" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.densenet import DenseNet,_DenseLayer,_DenseBlock" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dnet = DenseNet(12, (6, 12, 10), 1, 24, 80, 4, 0, True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "216 / 8" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(dnet, (1, 28, 952), device=\"cpu\", depth=3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - " backbone = nn.Sequential(\n", - " *list(dnet.children())[:][:-4]\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "backbone" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import WideResidualNetwork" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "w = WideResidualNetwork(\n", - " in_channels = 1,\n", - " in_planes = 32,\n", - " num_classes = 80,\n", - " depth = 10,\n", - " width_factor = 1,\n", - " dropout_rate = 0.0,\n", - " num_layers = 5,\n", - " activation = \"relu\",\n", - " use_decoder = False,)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "summary(w, (1, 28, 952), device=\"cpu\", depth=2)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "sz= 5" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask = torch.triu(torch.ones(sz, sz), 1)\n", - "mask = mask.masked_fill(mask==1, float('-inf'))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\n", - "h = torch.rand(1, 256, 10, 10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "h.flatten(2).permute(2, 0, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "h.flatten(2).permute(2, 0, 1).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred = torch.Tensor([1,21,2,45,31, 81, 1, 79, 79, 79, 2,1,1,1,1, 81, 1, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()\n", - "target = torch.Tensor([1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79, 1,1,1,1,1, 81, 79, 79, 79, 79]).long()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask = (target != 79)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred * mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target * mask" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.models.metrics import accuracy" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pad_indcies = torch.nonzero(target == 79, as_tuple=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t1 = torch.nonzero(target == 81, as_tuple=False).squeeze(1)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target.shape[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t2 = torch.arange(10, target.shape[0] + 1, 10)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t2" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "for start, stop in zip(t1, t2):\n", - " pred[start+1:stop] = 79" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "[pred[start+1:stop] = 79 for start, stop in zip(t1, t2)]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "pad_indcies" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred[pad_indcies:pad_indcies] = 79" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "pred.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "target.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "accuracy(pred, target)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "acc = (pred == target).sum().float() / target.shape[0]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "acc" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.2" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/src/text_recognizer/tests/support/emnist/8.png b/src/text_recognizer/tests/support/emnist/8.png Binary files differdeleted file mode 100644 index faa29aa..0000000 --- a/src/text_recognizer/tests/support/emnist/8.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/emnist/U.png b/src/text_recognizer/tests/support/emnist/U.png Binary files differdeleted file mode 100644 index 304eaec..0000000 --- a/src/text_recognizer/tests/support/emnist/U.png +++ /dev/null diff --git a/src/text_recognizer/tests/support/emnist/e.png b/src/text_recognizer/tests/support/emnist/e.png Binary files differdeleted file mode 100644 index a03ecd4..0000000 --- a/src/text_recognizer/tests/support/emnist/e.png +++ /dev/null diff --git a/src/tasks/build_transitions.py b/tasks/build_transitions.py index 91f8c1a..91f8c1a 100644 --- a/src/tasks/build_transitions.py +++ b/tasks/build_transitions.py diff --git a/src/tasks/create_emnist_lines_datasets.sh b/tasks/create_emnist_lines_datasets.sh index 6416277..6416277 100755 --- a/src/tasks/create_emnist_lines_datasets.sh +++ b/tasks/create_emnist_lines_datasets.sh diff --git a/src/tasks/create_iam_paragraphs.sh b/tasks/create_iam_paragraphs.sh index fa2bfb0..fa2bfb0 100755 --- a/src/tasks/create_iam_paragraphs.sh +++ b/tasks/create_iam_paragraphs.sh diff --git a/src/tasks/download_emnist.sh b/tasks/download_emnist.sh index 18c8e29..18c8e29 100755 --- a/src/tasks/download_emnist.sh +++ b/tasks/download_emnist.sh diff --git a/src/tasks/download_iam.sh b/tasks/download_iam.sh index e3cf76b..e3cf76b 100755 --- a/src/tasks/download_iam.sh +++ b/tasks/download_iam.sh diff --git a/src/tasks/make_wordpieces.py b/tasks/make_wordpieces.py index 2ac0e2c..2ac0e2c 100644 --- a/src/tasks/make_wordpieces.py +++ b/tasks/make_wordpieces.py diff --git a/src/tasks/prepare_experiments.sh b/tasks/prepare_experiments.sh index 95a538f..95a538f 100755 --- a/src/tasks/prepare_experiments.sh +++ b/tasks/prepare_experiments.sh diff --git a/src/tasks/test_functionality.sh b/tasks/test_functionality.sh index 5ccf0cd..5ccf0cd 100755 --- a/src/tasks/test_functionality.sh +++ b/tasks/test_functionality.sh diff --git a/src/tasks/train.sh b/tasks/train.sh index 60cbd23..60cbd23 100755 --- a/src/tasks/train.sh +++ b/tasks/train.sh diff --git a/src/text_recognizer/__init__.py b/text_recognizer/__init__.py index 3dc1f76..3dc1f76 100644 --- a/src/text_recognizer/__init__.py +++ b/text_recognizer/__init__.py diff --git a/src/text_recognizer/character_predictor.py b/text_recognizer/character_predictor.py index ad71289..ad71289 100644 --- a/src/text_recognizer/character_predictor.py +++ b/text_recognizer/character_predictor.py diff --git a/src/text_recognizer/datasets/__init__.py b/text_recognizer/datasets/__init__.py index a6c1c59..a6c1c59 100644 --- a/src/text_recognizer/datasets/__init__.py +++ b/text_recognizer/datasets/__init__.py diff --git a/src/text_recognizer/datasets/dataset.py b/text_recognizer/datasets/dataset.py index e794605..e794605 100644 --- a/src/text_recognizer/datasets/dataset.py +++ b/text_recognizer/datasets/dataset.py diff --git a/src/text_recognizer/datasets/emnist_dataset.py b/text_recognizer/datasets/emnist_dataset.py index 9884fdf..9884fdf 100644 --- a/src/text_recognizer/datasets/emnist_dataset.py +++ b/text_recognizer/datasets/emnist_dataset.py diff --git a/src/text_recognizer/datasets/emnist_essentials.json b/text_recognizer/datasets/emnist_essentials.json index 2a0648a..2a0648a 100644 --- a/src/text_recognizer/datasets/emnist_essentials.json +++ b/text_recognizer/datasets/emnist_essentials.json diff --git a/src/text_recognizer/datasets/emnist_lines_dataset.py b/text_recognizer/datasets/emnist_lines_dataset.py index 1992446..1992446 100644 --- a/src/text_recognizer/datasets/emnist_lines_dataset.py +++ b/text_recognizer/datasets/emnist_lines_dataset.py diff --git a/src/text_recognizer/datasets/iam_dataset.py b/text_recognizer/datasets/iam_dataset.py index f4a869d..a8998b9 100644 --- a/src/text_recognizer/datasets/iam_dataset.py +++ b/text_recognizer/datasets/iam_dataset.py @@ -13,6 +13,7 @@ from text_recognizer.datasets.util import _download_raw_dataset, DATA_DIRNAME RAW_DATA_DIRNAME = DATA_DIRNAME / "raw" / "iam" METADATA_FILENAME = RAW_DATA_DIRNAME / "metadata.toml" EXTRACTED_DATASET_DIRNAME = RAW_DATA_DIRNAME / "iamdb" +RAW_DATA_DIRNAME.mkdir(parents=True, exist_ok=True) DOWNSAMPLE_FACTOR = 2 # If images were downsampled, the regions must also be. LINE_REGION_PADDING = 0 # Add this many pixels around the exact coordinates. diff --git a/src/text_recognizer/datasets/iam_lines_dataset.py b/text_recognizer/datasets/iam_lines_dataset.py index 1cb84bd..1cb84bd 100644 --- a/src/text_recognizer/datasets/iam_lines_dataset.py +++ b/text_recognizer/datasets/iam_lines_dataset.py diff --git a/src/text_recognizer/datasets/iam_paragraphs_dataset.py b/text_recognizer/datasets/iam_paragraphs_dataset.py index 8ba5142..8ba5142 100644 --- a/src/text_recognizer/datasets/iam_paragraphs_dataset.py +++ b/text_recognizer/datasets/iam_paragraphs_dataset.py diff --git a/src/text_recognizer/datasets/iam_preprocessor.py b/text_recognizer/datasets/iam_preprocessor.py index a93eb00..a93eb00 100644 --- a/src/text_recognizer/datasets/iam_preprocessor.py +++ b/text_recognizer/datasets/iam_preprocessor.py diff --git a/src/text_recognizer/datasets/sentence_generator.py b/text_recognizer/datasets/sentence_generator.py index dd76652..dd76652 100644 --- a/src/text_recognizer/datasets/sentence_generator.py +++ b/text_recognizer/datasets/sentence_generator.py diff --git a/src/text_recognizer/datasets/transforms.py b/text_recognizer/datasets/transforms.py index b6a48f5..b6a48f5 100644 --- a/src/text_recognizer/datasets/transforms.py +++ b/text_recognizer/datasets/transforms.py diff --git a/src/text_recognizer/datasets/util.py b/text_recognizer/datasets/util.py index da87756..da87756 100644 --- a/src/text_recognizer/datasets/util.py +++ b/text_recognizer/datasets/util.py diff --git a/src/text_recognizer/line_predictor.py b/text_recognizer/line_predictor.py index 8e348fe..8e348fe 100644 --- a/src/text_recognizer/line_predictor.py +++ b/text_recognizer/line_predictor.py diff --git a/src/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py index 7647d7e..7647d7e 100644 --- a/src/text_recognizer/models/__init__.py +++ b/text_recognizer/models/__init__.py diff --git a/src/text_recognizer/models/base.py b/text_recognizer/models/base.py index 70f4cdb..70f4cdb 100644 --- a/src/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py diff --git a/src/text_recognizer/models/character_model.py b/text_recognizer/models/character_model.py index f9944f3..f9944f3 100644 --- a/src/text_recognizer/models/character_model.py +++ b/text_recognizer/models/character_model.py diff --git a/src/text_recognizer/models/crnn_model.py b/text_recognizer/models/crnn_model.py index 1e01a83..1e01a83 100644 --- a/src/text_recognizer/models/crnn_model.py +++ b/text_recognizer/models/crnn_model.py diff --git a/src/text_recognizer/models/ctc_transformer_model.py b/text_recognizer/models/ctc_transformer_model.py index 25925f2..25925f2 100644 --- a/src/text_recognizer/models/ctc_transformer_model.py +++ b/text_recognizer/models/ctc_transformer_model.py diff --git a/src/text_recognizer/models/segmentation_model.py b/text_recognizer/models/segmentation_model.py index 613108a..613108a 100644 --- a/src/text_recognizer/models/segmentation_model.py +++ b/text_recognizer/models/segmentation_model.py diff --git a/src/text_recognizer/models/transformer_model.py b/text_recognizer/models/transformer_model.py index 3f63053..3f63053 100644 --- a/src/text_recognizer/models/transformer_model.py +++ b/text_recognizer/models/transformer_model.py diff --git a/src/text_recognizer/models/vqvae_model.py b/text_recognizer/models/vqvae_model.py index 70f6f1f..70f6f1f 100644 --- a/src/text_recognizer/models/vqvae_model.py +++ b/text_recognizer/models/vqvae_model.py diff --git a/src/text_recognizer/networks/__init__.py b/text_recognizer/networks/__init__.py index 1521355..1521355 100644 --- a/src/text_recognizer/networks/__init__.py +++ b/text_recognizer/networks/__init__.py diff --git a/src/text_recognizer/networks/beam.py b/text_recognizer/networks/beam.py index dccccdb..dccccdb 100644 --- a/src/text_recognizer/networks/beam.py +++ b/text_recognizer/networks/beam.py diff --git a/src/text_recognizer/networks/cnn.py b/text_recognizer/networks/cnn.py index 1807bb9..1807bb9 100644 --- a/src/text_recognizer/networks/cnn.py +++ b/text_recognizer/networks/cnn.py diff --git a/src/text_recognizer/networks/cnn_transformer.py b/text_recognizer/networks/cnn_transformer.py index a2d7926..9150b55 100644 --- a/src/text_recognizer/networks/cnn_transformer.py +++ b/text_recognizer/networks/cnn_transformer.py @@ -1,7 +1,7 @@ """A CNN-Transformer for image to text recognition.""" from typing import Dict, Optional, Tuple -from einops import rearrange, repeat +from einops import rearrange import torch from torch import nn from torch import Tensor diff --git a/src/text_recognizer/networks/crnn.py b/text_recognizer/networks/crnn.py index 778e232..778e232 100644 --- a/src/text_recognizer/networks/crnn.py +++ b/text_recognizer/networks/crnn.py diff --git a/src/text_recognizer/networks/ctc.py b/text_recognizer/networks/ctc.py index af9b700..af9b700 100644 --- a/src/text_recognizer/networks/ctc.py +++ b/text_recognizer/networks/ctc.py diff --git a/src/text_recognizer/networks/densenet.py b/text_recognizer/networks/densenet.py index 7dc58d9..7dc58d9 100644 --- a/src/text_recognizer/networks/densenet.py +++ b/text_recognizer/networks/densenet.py diff --git a/src/text_recognizer/networks/lenet.py b/text_recognizer/networks/lenet.py index 527e1a0..527e1a0 100644 --- a/src/text_recognizer/networks/lenet.py +++ b/text_recognizer/networks/lenet.py diff --git a/src/text_recognizer/networks/loss/__init__.py b/text_recognizer/networks/loss/__init__.py index b489264..b489264 100644 --- a/src/text_recognizer/networks/loss/__init__.py +++ b/text_recognizer/networks/loss/__init__.py diff --git a/src/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py index cf9fa0d..cf9fa0d 100644 --- a/src/text_recognizer/networks/loss/loss.py +++ b/text_recognizer/networks/loss/loss.py diff --git a/src/text_recognizer/networks/metrics.py b/text_recognizer/networks/metrics.py index 2605731..2605731 100644 --- a/src/text_recognizer/networks/metrics.py +++ b/text_recognizer/networks/metrics.py diff --git a/src/text_recognizer/networks/mlp.py b/text_recognizer/networks/mlp.py index 1101912..1101912 100644 --- a/src/text_recognizer/networks/mlp.py +++ b/text_recognizer/networks/mlp.py diff --git a/src/text_recognizer/networks/residual_network.py b/text_recognizer/networks/residual_network.py index c33f419..c33f419 100644 --- a/src/text_recognizer/networks/residual_network.py +++ b/text_recognizer/networks/residual_network.py diff --git a/src/text_recognizer/networks/stn.py b/text_recognizer/networks/stn.py index e9d216f..e9d216f 100644 --- a/src/text_recognizer/networks/stn.py +++ b/text_recognizer/networks/stn.py diff --git a/src/text_recognizer/networks/transducer/__init__.py b/text_recognizer/networks/transducer/__init__.py index 8c19a01..8c19a01 100644 --- a/src/text_recognizer/networks/transducer/__init__.py +++ b/text_recognizer/networks/transducer/__init__.py diff --git a/src/text_recognizer/networks/transducer/tds_conv.py b/text_recognizer/networks/transducer/tds_conv.py index 5fb8ba9..5fb8ba9 100644 --- a/src/text_recognizer/networks/transducer/tds_conv.py +++ b/text_recognizer/networks/transducer/tds_conv.py diff --git a/src/text_recognizer/networks/transducer/test.py b/text_recognizer/networks/transducer/test.py index cadcecc..cadcecc 100644 --- a/src/text_recognizer/networks/transducer/test.py +++ b/text_recognizer/networks/transducer/test.py diff --git a/src/text_recognizer/networks/transducer/transducer.py b/text_recognizer/networks/transducer/transducer.py index d7e3d08..d7e3d08 100644 --- a/src/text_recognizer/networks/transducer/transducer.py +++ b/text_recognizer/networks/transducer/transducer.py diff --git a/src/text_recognizer/networks/transformer/__init__.py b/text_recognizer/networks/transformer/__init__.py index 9febc88..9febc88 100644 --- a/src/text_recognizer/networks/transformer/__init__.py +++ b/text_recognizer/networks/transformer/__init__.py diff --git a/src/text_recognizer/networks/transformer/attention.py b/text_recognizer/networks/transformer/attention.py index cce1ecc..cce1ecc 100644 --- a/src/text_recognizer/networks/transformer/attention.py +++ b/text_recognizer/networks/transformer/attention.py diff --git a/src/text_recognizer/networks/transformer/positional_encoding.py b/text_recognizer/networks/transformer/positional_encoding.py index 1ba5537..1ba5537 100644 --- a/src/text_recognizer/networks/transformer/positional_encoding.py +++ b/text_recognizer/networks/transformer/positional_encoding.py diff --git a/src/text_recognizer/networks/transformer/transformer.py b/text_recognizer/networks/transformer/transformer.py index dd180c4..dd180c4 100644 --- a/src/text_recognizer/networks/transformer/transformer.py +++ b/text_recognizer/networks/transformer/transformer.py diff --git a/src/text_recognizer/networks/unet.py b/text_recognizer/networks/unet.py index 510910f..510910f 100644 --- a/src/text_recognizer/networks/unet.py +++ b/text_recognizer/networks/unet.py diff --git a/src/text_recognizer/networks/util.py b/text_recognizer/networks/util.py index 131a6b4..131a6b4 100644 --- a/src/text_recognizer/networks/util.py +++ b/text_recognizer/networks/util.py diff --git a/src/text_recognizer/networks/vit.py b/text_recognizer/networks/vit.py index efb3701..efb3701 100644 --- a/src/text_recognizer/networks/vit.py +++ b/text_recognizer/networks/vit.py diff --git a/src/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index c673d96..c673d96 100644 --- a/src/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/text_recognizer/networks/vqvae/__init__.py index 763953c..763953c 100644 --- a/src/text_recognizer/networks/vqvae/__init__.py +++ b/text_recognizer/networks/vqvae/__init__.py diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index 8847aba..8847aba 100644 --- a/src/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index d3adac5..d3adac5 100644 --- a/src/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/text_recognizer/networks/vqvae/vector_quantizer.py index f92c7ee..f92c7ee 100644 --- a/src/text_recognizer/networks/vqvae/vector_quantizer.py +++ b/text_recognizer/networks/vqvae/vector_quantizer.py diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 50448b4..50448b4 100644 --- a/src/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py diff --git a/src/text_recognizer/networks/wide_resnet.py b/text_recognizer/networks/wide_resnet.py index b767778..b767778 100644 --- a/src/text_recognizer/networks/wide_resnet.py +++ b/text_recognizer/networks/wide_resnet.py diff --git a/src/text_recognizer/paragraph_text_recognizer.py b/text_recognizer/paragraph_text_recognizer.py index aa39662..aa39662 100644 --- a/src/text_recognizer/paragraph_text_recognizer.py +++ b/text_recognizer/paragraph_text_recognizer.py diff --git a/src/text_recognizer/tests/__init__.py b/text_recognizer/tests/__init__.py index 18ff212..18ff212 100644 --- a/src/text_recognizer/tests/__init__.py +++ b/text_recognizer/tests/__init__.py diff --git a/src/text_recognizer/tests/support/__init__.py b/text_recognizer/tests/support/__init__.py index a265ede..a265ede 100644 --- a/src/text_recognizer/tests/support/__init__.py +++ b/text_recognizer/tests/support/__init__.py diff --git a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py b/text_recognizer/tests/support/create_emnist_lines_support_files.py index 9abe143..9abe143 100644 --- a/src/text_recognizer/tests/support/create_emnist_lines_support_files.py +++ b/text_recognizer/tests/support/create_emnist_lines_support_files.py diff --git a/src/text_recognizer/tests/support/create_emnist_support_files.py b/text_recognizer/tests/support/create_emnist_support_files.py index f9ff030..f9ff030 100644 --- a/src/text_recognizer/tests/support/create_emnist_support_files.py +++ b/text_recognizer/tests/support/create_emnist_support_files.py diff --git a/src/text_recognizer/tests/support/create_iam_lines_support_files.py b/text_recognizer/tests/support/create_iam_lines_support_files.py index 50f9e3d..50f9e3d 100644 --- a/src/text_recognizer/tests/support/create_iam_lines_support_files.py +++ b/text_recognizer/tests/support/create_iam_lines_support_files.py diff --git a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png Binary files differindex b7d0618..b7d0618 100644 --- a/src/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png +++ b/text_recognizer/tests/support/emnist_lines/Knox Ky<eos>.png diff --git a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png Binary files differindex 14a8cf3..14a8cf3 100644 --- a/src/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png +++ b/text_recognizer/tests/support/emnist_lines/ancillary beliefs and<eos>.png diff --git a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png b/text_recognizer/tests/support/emnist_lines/they<eos>.png Binary files differindex 7f05951..7f05951 100644 --- a/src/text_recognizer/tests/support/emnist_lines/they<eos>.png +++ b/text_recognizer/tests/support/emnist_lines/they<eos>.png diff --git a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png Binary files differindex 6eeb642..6eeb642 100644 --- a/src/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png +++ b/text_recognizer/tests/support/iam_lines/He rose from his breakfast-nook bench<eos>.png diff --git a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png Binary files differindex 4974cf8..4974cf8 100644 --- a/src/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png +++ b/text_recognizer/tests/support/iam_lines/and came into the livingroom, where<eos>.png diff --git a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png Binary files differindex a731245..a731245 100644 --- a/src/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png +++ b/text_recognizer/tests/support/iam_lines/his entrance. He came, almost falling<eos>.png diff --git a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg Binary files differindex d9753b6..d9753b6 100644 --- a/src/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg +++ b/text_recognizer/tests/support/iam_paragraphs/a01-000u.jpg diff --git a/src/text_recognizer/tests/test_character_predictor.py b/text_recognizer/tests/test_character_predictor.py index 01bda78..01bda78 100644 --- a/src/text_recognizer/tests/test_character_predictor.py +++ b/text_recognizer/tests/test_character_predictor.py diff --git a/src/text_recognizer/tests/test_line_predictor.py b/text_recognizer/tests/test_line_predictor.py index eede4d4..eede4d4 100644 --- a/src/text_recognizer/tests/test_line_predictor.py +++ b/text_recognizer/tests/test_line_predictor.py diff --git a/src/text_recognizer/tests/test_paragraph_text_recognizer.py b/text_recognizer/tests/test_paragraph_text_recognizer.py index 3e280b9..3e280b9 100644 --- a/src/text_recognizer/tests/test_paragraph_text_recognizer.py +++ b/text_recognizer/tests/test_paragraph_text_recognizer.py diff --git a/src/text_recognizer/util.py b/text_recognizer/util.py index b431e22..b431e22 100644 --- a/src/text_recognizer/util.py +++ b/text_recognizer/util.py diff --git a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt index 344e0a3..344e0a3 100644 --- a/src/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt +++ b/text_recognizer/weights/CRNNModel_IamLinesDataset_ConvolutionalRecurrentNetwork_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt index f2dfd84..f2dfd84 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt +++ b/text_recognizer/weights/CharacterModel_EmnistDataset_DenseNet_weights.pt diff --git a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt index e1add8d..e1add8d 100644 --- a/src/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt +++ b/text_recognizer/weights/CharacterModel_EmnistDataset_WideResidualNetwork_weights.pt diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt Binary files differindex d9ca01d..d9ca01d 100644 --- a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt +++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_FCN_weights.pt diff --git a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt Binary files differindex 0af0e57..0af0e57 100644 --- a/src/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt +++ b/text_recognizer/weights/SegmentationModel_IamParagraphsDataset_UNet_weights.pt diff --git a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt Binary files differindex b5295c2..b5295c2 100644 --- a/src/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt +++ b/text_recognizer/weights/VQVAEModel_IamLinesDataset_VQVAE_weights.pt diff --git a/src/training/experiments/default_config_emnist.yml b/training/experiments/default_config_emnist.yml index bf2ed0a..bf2ed0a 100644 --- a/src/training/experiments/default_config_emnist.yml +++ b/training/experiments/default_config_emnist.yml diff --git a/src/training/experiments/embedding_experiment.yml b/training/experiments/embedding_experiment.yml index 1e5f941..1e5f941 100644 --- a/src/training/experiments/embedding_experiment.yml +++ b/training/experiments/embedding_experiment.yml diff --git a/src/training/experiments/sample_experiment.yml b/training/experiments/sample_experiment.yml index 8f94475..8f94475 100644 --- a/src/training/experiments/sample_experiment.yml +++ b/training/experiments/sample_experiment.yml diff --git a/src/training/gpu_manager.py b/training/gpu_manager.py index ce1b3dd..ce1b3dd 100644 --- a/src/training/gpu_manager.py +++ b/training/gpu_manager.py diff --git a/src/training/prepare_experiments.py b/training/prepare_experiments.py index 21997af..21997af 100644 --- a/src/training/prepare_experiments.py +++ b/training/prepare_experiments.py diff --git a/src/training/run_experiment.py b/training/run_experiment.py index faafea6..faafea6 100644 --- a/src/training/run_experiment.py +++ b/training/run_experiment.py diff --git a/src/training/run_sweep.py b/training/run_sweep.py index a578592..a578592 100644 --- a/src/training/run_sweep.py +++ b/training/run_sweep.py diff --git a/src/training/sweep_emnist.yml b/training/sweep_emnist.yml index 48d7261..48d7261 100644 --- a/src/training/sweep_emnist.yml +++ b/training/sweep_emnist.yml diff --git a/src/training/sweep_emnist_resnet.yml b/training/sweep_emnist_resnet.yml index 19a3040..19a3040 100644 --- a/src/training/sweep_emnist_resnet.yml +++ b/training/sweep_emnist_resnet.yml diff --git a/src/training/trainer/__init__.py b/training/trainer/__init__.py index de41bfb..de41bfb 100644 --- a/src/training/trainer/__init__.py +++ b/training/trainer/__init__.py diff --git a/src/training/trainer/callbacks/__init__.py b/training/trainer/callbacks/__init__.py index 80c4177..80c4177 100644 --- a/src/training/trainer/callbacks/__init__.py +++ b/training/trainer/callbacks/__init__.py diff --git a/src/training/trainer/callbacks/base.py b/training/trainer/callbacks/base.py index 500b642..500b642 100644 --- a/src/training/trainer/callbacks/base.py +++ b/training/trainer/callbacks/base.py diff --git a/src/training/trainer/callbacks/checkpoint.py b/training/trainer/callbacks/checkpoint.py index a54e0a9..a54e0a9 100644 --- a/src/training/trainer/callbacks/checkpoint.py +++ b/training/trainer/callbacks/checkpoint.py diff --git a/src/training/trainer/callbacks/early_stopping.py b/training/trainer/callbacks/early_stopping.py index 02b431f..02b431f 100644 --- a/src/training/trainer/callbacks/early_stopping.py +++ b/training/trainer/callbacks/early_stopping.py diff --git a/src/training/trainer/callbacks/lr_schedulers.py b/training/trainer/callbacks/lr_schedulers.py index 630c434..630c434 100644 --- a/src/training/trainer/callbacks/lr_schedulers.py +++ b/training/trainer/callbacks/lr_schedulers.py diff --git a/src/training/trainer/callbacks/progress_bar.py b/training/trainer/callbacks/progress_bar.py index 6c4305a..6c4305a 100644 --- a/src/training/trainer/callbacks/progress_bar.py +++ b/training/trainer/callbacks/progress_bar.py diff --git a/src/training/trainer/callbacks/wandb_callbacks.py b/training/trainer/callbacks/wandb_callbacks.py index 552a4f4..552a4f4 100644 --- a/src/training/trainer/callbacks/wandb_callbacks.py +++ b/training/trainer/callbacks/wandb_callbacks.py diff --git a/src/training/trainer/train.py b/training/trainer/train.py index b770c94..b770c94 100644 --- a/src/training/trainer/train.py +++ b/training/trainer/train.py diff --git a/src/training/trainer/util.py b/training/trainer/util.py index 7cf1b45..7cf1b45 100644 --- a/src/training/trainer/util.py +++ b/training/trainer/util.py diff --git a/src/wandb/settings b/wandb/settings index eafb083..eafb083 100644 --- a/src/wandb/settings +++ b/wandb/settings |