From 4defc734b681071e19dd86404abd416d24330b9a Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Tue, 4 May 2021 23:24:00 +0200 Subject: Bug fix --- notebooks/00-scratch-pad.ipynb | 598 +++++++++++++++++++++ notebooks/00-testing-stuff-out.ipynb | 547 ------------------- .../networks/transformer/nystromer/nystromer.py | 34 +- 3 files changed, 616 insertions(+), 563 deletions(-) create mode 100644 notebooks/00-scratch-pad.ipynb delete mode 100644 notebooks/00-testing-stuff-out.ipynb diff --git a/notebooks/00-scratch-pad.ipynb b/notebooks/00-scratch-pad.ipynb new file mode 100644 index 0000000..d50fd59 --- /dev/null +++ b/notebooks/00-scratch-pad.ipynb @@ -0,0 +1,598 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "%autoreload 2\n", + "\n", + "%matplotlib inline\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from PIL import Image\n", + "import torch.nn.functional as F\n", + "import torch\n", + "from torch import nn\n", + "from torchsummary import summary\n", + "from importlib.util import find_spec\n", + "if find_spec(\"text_recognizer\") is None:\n", + " import sys\n", + " sys.path.append('..')" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "ModuleList(\n", + " (0): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (1): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (2): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (3): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (4): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (5): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (6): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (7): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (8): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + " (9): ModuleList(\n", + " (0): Linear(in_features=10, out_features=10, bias=True)\n", + " )\n", + ")" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "nn.ModuleList([nn.ModuleList([nn.Linear(10, 10)]) for _ in range(10)])" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "path = \"../training/configs/vqvae.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "conf = OmegaConf.load(path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "print(OmegaConf.to_yaml(conf))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks import VQVAE" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae = VQVAE(**conf.network.args)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datum = torch.randn([2, 1, 576, 640])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae.encoder(datum)[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "vae(datum)[0].shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.backbones.efficientnet import EfficientNet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "en = EfficientNet()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "datum = torch.randn([2, 1, 576, 640])" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "trg = torch.randint(0, 1000, [2, 682])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trg.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "datum = torch.randn([2, 1, 224, 224])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "en(datum).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "path = \"../training/configs/cnn_transformer.yaml\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "conf = OmegaConf.load(path)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "seed: 4711\n", + "network:\n", + " desc: Configuration of the PyTorch neural network.\n", + " type: CNNTransformer\n", + " args:\n", + " encoder:\n", + " type: EfficientNet\n", + " args: null\n", + " num_decoder_layers: 4\n", + " hidden_dim: 256\n", + " num_heads: 4\n", + " expansion_dim: 1024\n", + " dropout_rate: 0.1\n", + " transformer_activation: glu\n", + "model:\n", + " desc: Configuration of the PyTorch Lightning model.\n", + " type: LitTransformerModel\n", + " args:\n", + " optimizer:\n", + " type: MADGRAD\n", + " args:\n", + " lr: 0.001\n", + " momentum: 0.9\n", + " weight_decay: 0\n", + " eps: 1.0e-06\n", + " lr_scheduler:\n", + " type: OneCycleLR\n", + " args:\n", + " interval: step\n", + " max_lr: 0.001\n", + " three_phase: true\n", + " epochs: 512\n", + " steps_per_epoch: 1246\n", + " criterion:\n", + " type: CrossEntropyLoss\n", + " args:\n", + " weight: None\n", + " ignore_index: -100\n", + " reduction: mean\n", + " monitor: val_loss\n", + " mapping: sentence_piece\n", + "data:\n", + " desc: Configuration of the training/test data.\n", + " type: IAMExtendedParagraphs\n", + " args:\n", + " batch_size: 16\n", + " num_workers: 12\n", + " train_fraction: 0.8\n", + " augment: true\n", + "callbacks:\n", + "- type: ModelCheckpoint\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " save_last: true\n", + "- type: StochasticWeightAveraging\n", + " args:\n", + " swa_epoch_start: 0.8\n", + " swa_lrs: 0.05\n", + " annealing_epochs: 10\n", + " annealing_strategy: cos\n", + " device: null\n", + "- type: LearningRateMonitor\n", + " args:\n", + " logging_interval: step\n", + "- type: EarlyStopping\n", + " args:\n", + " monitor: val_loss\n", + " mode: min\n", + " patience: 10\n", + "trainer:\n", + " desc: Configuration of the PyTorch Lightning Trainer.\n", + " args:\n", + " stochastic_weight_avg: true\n", + " auto_scale_batch_size: binsearch\n", + " gradient_clip_val: 0\n", + " fast_dev_run: false\n", + " gpus: 1\n", + " precision: 16\n", + " max_epochs: 512\n", + " terminate_on_nan: true\n", + " weights_summary: true\n", + "load_checkpoint: null\n", + "\n" + ] + } + ], + "source": [ + "print(OmegaConf.to_yaml(conf))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "from text_recognizer.networks.cnn_transformer import CNNTransformer" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "t.encode(datum).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "trg.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([2, 682, 1004])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t(datum, trg).shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "b, n = 16, 128\n", + "device = \"cpu\"" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "x = lambda: torch.ones((b, n), device=device).bool()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 128])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([16, 128])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.ones((b, n), device=device).bool().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 1, 576, 640)" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "144" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "576 // 4" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "160" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "640 // 4" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "x = torch.randn(1, 1, 144, 160)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from einops import rearrange" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "patch_size=16\n", + "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([1, 1440, 256])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "p.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/notebooks/00-testing-stuff-out.ipynb b/notebooks/00-testing-stuff-out.ipynb deleted file mode 100644 index 12c5145..0000000 --- a/notebooks/00-testing-stuff-out.ipynb +++ /dev/null @@ -1,547 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "%autoreload 2\n", - "\n", - "%matplotlib inline\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "from PIL import Image\n", - "import torch.nn.functional as F\n", - "import torch\n", - "from torch import nn\n", - "from torchsummary import summary\n", - "from importlib.util import find_spec\n", - "if find_spec(\"text_recognizer\") is None:\n", - " import sys\n", - " sys.path.append('..')" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "from omegaconf import OmegaConf" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "path = \"../training/configs/vqvae.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "conf = OmegaConf.load(path)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "print(OmegaConf.to_yaml(conf))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks import VQVAE" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae = VQVAE(**conf.network.args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datum = torch.randn([2, 1, 576, 640])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae.encoder(datum)[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "vae(datum)[0].shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.backbones.efficientnet import EfficientNet" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "en = EfficientNet()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "datum = torch.randn([2, 1, 576, 640])" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "trg = torch.randint(0, 1000, [2, 682])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trg.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "datum = torch.randn([2, 1, 224, 224])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "en(datum).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "path = \"../training/configs/cnn_transformer.yaml\"" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "conf = OmegaConf.load(path)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "seed: 4711\n", - "network:\n", - " desc: Configuration of the PyTorch neural network.\n", - " type: CNNTransformer\n", - " args:\n", - " encoder:\n", - " type: EfficientNet\n", - " args: null\n", - " num_decoder_layers: 4\n", - " hidden_dim: 256\n", - " num_heads: 4\n", - " expansion_dim: 1024\n", - " dropout_rate: 0.1\n", - " transformer_activation: glu\n", - "model:\n", - " desc: Configuration of the PyTorch Lightning model.\n", - " type: LitTransformerModel\n", - " args:\n", - " optimizer:\n", - " type: MADGRAD\n", - " args:\n", - " lr: 0.001\n", - " momentum: 0.9\n", - " weight_decay: 0\n", - " eps: 1.0e-06\n", - " lr_scheduler:\n", - " type: OneCycleLR\n", - " args:\n", - " interval: step\n", - " max_lr: 0.001\n", - " three_phase: true\n", - " epochs: 512\n", - " steps_per_epoch: 1246\n", - " criterion:\n", - " type: CrossEntropyLoss\n", - " args:\n", - " weight: None\n", - " ignore_index: -100\n", - " reduction: mean\n", - " monitor: val_loss\n", - " mapping: sentence_piece\n", - "data:\n", - " desc: Configuration of the training/test data.\n", - " type: IAMExtendedParagraphs\n", - " args:\n", - " batch_size: 16\n", - " num_workers: 12\n", - " train_fraction: 0.8\n", - " augment: true\n", - "callbacks:\n", - "- type: ModelCheckpoint\n", - " args:\n", - " monitor: val_loss\n", - " mode: min\n", - " save_last: true\n", - "- type: StochasticWeightAveraging\n", - " args:\n", - " swa_epoch_start: 0.8\n", - " swa_lrs: 0.05\n", - " annealing_epochs: 10\n", - " annealing_strategy: cos\n", - " device: null\n", - "- type: LearningRateMonitor\n", - " args:\n", - " logging_interval: step\n", - "- type: EarlyStopping\n", - " args:\n", - " monitor: val_loss\n", - " mode: min\n", - " patience: 10\n", - "trainer:\n", - " desc: Configuration of the PyTorch Lightning Trainer.\n", - " args:\n", - " stochastic_weight_avg: true\n", - " auto_scale_batch_size: binsearch\n", - " gradient_clip_val: 0\n", - " fast_dev_run: false\n", - " gpus: 1\n", - " precision: 16\n", - " max_epochs: 512\n", - " terminate_on_nan: true\n", - " weights_summary: true\n", - "load_checkpoint: null\n", - "\n" - ] - } - ], - "source": [ - "print(OmegaConf.to_yaml(conf))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "from text_recognizer.networks.cnn_transformer import CNNTransformer" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [], - "source": [ - "t = CNNTransformer(input_shape=(1, 576, 640), output_shape=(682, 1), **conf.network.args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "t.encode(datum).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "trg.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([2, 682, 1004])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "t(datum, trg).shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [], - "source": [ - "b, n = 16, 128\n", - "device = \"cpu\"" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "x = lambda: torch.ones((b, n), device=device).bool()" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([16, 128])" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "x().shape" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([16, 128])" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "torch.ones((b, n), device=device).bool().shape" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randn(1, 1, 576, 640)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "144" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "576 // 4" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "160" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "640 // 4" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "metadata": {}, - "outputs": [], - "source": [ - "x = torch.randn(1, 1, 144, 160)" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "from einops import rearrange" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "patch_size=16\n", - "p = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([1, 1440, 256])" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "p.shape" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.9.1" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} diff --git a/text_recognizer/networks/transformer/nystromer/nystromer.py b/text_recognizer/networks/transformer/nystromer/nystromer.py index 0283d69..7cc889e 100644 --- a/text_recognizer/networks/transformer/nystromer/nystromer.py +++ b/text_recognizer/networks/transformer/nystromer/nystromer.py @@ -30,24 +30,26 @@ class Nystromer(nn.Module): super().__init__() self.layers = nn.ModuleList( [ - [ - PreNorm( - dim, - NystromAttention( - dim=dim, - dim_head=dim_head, - num_heads=num_heads, - num_landmarks=num_landmarks, - inverse_iter=inverse_iter, - residual=residual, - residual_conv_kernel=residual_conv_kernel, - dropout_rate=dropout_rate, + nn.ModuleList( + [ + PreNorm( + dim, + NystromAttention( + dim=dim, + dim_head=dim_head, + num_heads=num_heads, + num_landmarks=num_landmarks, + inverse_iter=inverse_iter, + residual=residual, + residual_conv_kernel=residual_conv_kernel, + dropout_rate=dropout_rate, + ), ), - ), - PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)), - ] + PreNorm(dim, FeedForward(dim=dim, dropout_rate=dropout_rate)), + ] + ) + for _ in range(depth) ] - for _ in range(depth) ) def forward(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: -- cgit v1.2.3-70-g09d2