summaryrefslogtreecommitdiff
path: root/notebooks
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-25 22:29:55 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-25 22:29:55 +0200
commitf78ad6e6adee4c90ad1b29d6058ece186bb423a4 (patch)
treedaf7c4972946de0009c839e83691b65c84b1550a /notebooks
parentfff7967447dc67b5340200760356b2de85b3969a (diff)
Update notebooks
Diffstat (limited to 'notebooks')
-rw-r--r--notebooks/03-look-at-iam-paragraphs.ipynb119
-rw-r--r--notebooks/04-efficientnet-transformer.ipynb256
2 files changed, 12 insertions, 363 deletions
diff --git a/notebooks/03-look-at-iam-paragraphs.ipynb b/notebooks/03-look-at-iam-paragraphs.ipynb
index f57d491..b55aa12 100644
--- a/notebooks/03-look-at-iam-paragraphs.ipynb
+++ b/notebooks/03-look-at-iam-paragraphs.ipynb
@@ -275,8 +275,7 @@
"source": [
"# context initialization\n",
"with initialize(config_path=\"../training/conf/\"):\n",
- " cfg = compose(config_name=\"config\", overrides=[\"+experiment=cnn_transformer_paragraphs\"])\n",
- " print(OmegaConf.to_yaml(cfg))"
+ " cfg = compose(config_name=\"config\", overrides=[\"+experiment=cnn_transformer_paragraphs\"])"
]
},
{
@@ -471,122 +470,6 @@
" x, y = dataset[i]\n",
" _plot(x[0], vmax=1, title=convert_y_label_to_string(y, datamodule.mapping))"
]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "id": "6f53950a-6858-40b6-ad4a-2857227d9d59",
- "metadata": {},
- "outputs": [],
- "source": [
- "a = torch.randn(2, 1, 576, 640), torch.randn(2, 1, 576, 640)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "id": "580b9bbc-b213-4ef8-aefb-cfb453d08a44",
- "metadata": {},
- "outputs": [],
- "source": [
- "from text_recognizer.data.transforms.load_transform import load_transforms"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "id": "15744542-0880-45f2-881b-d81e04668305",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Compose(\n",
- " <text_recognizer.data.transforms.embed_crop.EmbedCrop object at 0x7f830edf5df0>\n",
- " ColorJitter(brightness=[0.8, 1.6], contrast=None, saturation=None, hue=None)\n",
- " RandomAffine(degrees=[-1.0, 1.0], shear=[-30.0, 20.0], interpolation=bilinear)\n",
- " ToTensor()\n",
- ")"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "load_transforms(\"iam_lines.yaml\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 16,
- "id": "5d7eab42-c407-4b88-9492-e9279a38232a",
- "metadata": {},
- "outputs": [],
- "source": [
- "from torchvision.transforms import ColorJitter"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 17,
- "id": "3d02a3fe-1128-416f-80e1-84c9287e613d",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "ColorJitter(brightness=[0.5, 1.0], contrast=None, saturation=None, hue=None)"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "ColorJitter(brightness=[0.5, 1.0])"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 45,
- "id": "f0ead6d1-3093-4a42-a3b2-b3cdea75fc21",
- "metadata": {},
- "outputs": [],
- "source": [
- "import torchvision"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 46,
- "id": "f4c1606d-a063-465d-bf22-61a1cbc14ab9",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<InterpolationMode.BILINEAR: 'bilinear'>"
- ]
- },
- "execution_count": 46,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "getattr(torchvision.transforms.functional.InterpolationMode, \"BILINEAR\")"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "id": "617568a7-fde1-4f60-80c5-922d764f0c52",
- "metadata": {},
- "outputs": [],
- "source": []
}
],
"metadata": {
diff --git a/notebooks/04-efficientnet-transformer.ipynb b/notebooks/04-efficientnet-transformer.ipynb
index 8affa9d..145fe78 100644
--- a/notebooks/04-efficientnet-transformer.ipynb
+++ b/notebooks/04-efficientnet-transformer.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
+ "execution_count": null,
"id": "7c02ae76-b540-4b16-9492-e9210b3b9249",
"metadata": {},
"outputs": [],
@@ -28,7 +28,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": null,
"id": "ccdb6dde-47e5-429a-88f2-0764fb7e259a",
"metadata": {},
"outputs": [],
@@ -40,7 +40,7 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": null,
"id": "3cf50475-39f2-4642-a7d1-5bcbc0a036f7",
"metadata": {},
"outputs": [],
@@ -50,7 +50,7 @@
},
{
"cell_type": "code",
- "execution_count": 7,
+ "execution_count": null,
"id": "e52ecb01-c975-4e55-925d-1182c7aea473",
"metadata": {},
"outputs": [],
@@ -61,28 +61,17 @@
},
{
"cell_type": "code",
- "execution_count": 8,
+ "execution_count": null,
"id": "f939aa37-7b1d-45cc-885c-323c4540bda1",
"metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "{'defaults': [{'override /mapping': None}, {'override /criterion': None}, {'override /datamodule': None}, {'override /network': None}, {'override /model': None}, {'override /lr_schedulers': None}, {'override /optimizers': None}], 'ignore_index': 3, 'num_classes': 58, 'max_output_len': 89, 'criterion': {'_target_': 'torch.nn.CrossEntropyLoss', 'ignore_index': 3}, 'mapping': {'_target_': 'text_recognizer.data.emnist_mapping.EmnistMapping'}, 'callbacks': {'stochastic_weight_averaging': {'_target_': 'pytorch_lightning.callbacks.StochasticWeightAveraging', 'swa_epoch_start': 0.8, 'swa_lrs': 0.05, 'annealing_epochs': 10, 'annealing_strategy': 'cos', 'device': None}}, 'optimizers': {'madgrad': {'_target_': 'madgrad.MADGRAD', 'lr': 0.0003, 'momentum': 0.9, 'weight_decay': 0, 'eps': 1e-06, 'parameters': 'network'}}, 'lr_schedulers': {'network': {'_target_': 'torch.optim.lr_scheduler.ReduceLROnPlateau', 'mode': 'min', 'factor': 0.5, 'patience': 10, 'threshold': 0.0001, 'threshold_mode': 'rel', 'cooldown': 0, 'min_lr': 1e-06, 'eps': 1e-08, 'interval': 'epoch', 'monitor': 'val/loss'}}, 'datamodule': {'_target_': 'text_recognizer.data.iam_lines.IAMLines', 'batch_size': 16, 'num_workers': 12, 'train_fraction': 0.8, 'augment': True, 'pin_memory': True, 'word_pieces': False}, 'network': {'_target_': 'text_recognizer.networks.conv_transformer.ConvTransformer', 'input_dims': [1, 56, 1024], 'hidden_dim': 128, 'encoder_dim': 1280, 'dropout_rate': 0.2, 'num_classes': 58, 'pad_index': 3, 'encoder': {'_target_': 'text_recognizer.networks.encoders.efficientnet.EfficientNet', 'arch': 'b0', 'out_channels': 1280, 'stochastic_dropout_rate': 0.2, 'bn_momentum': 0.99, 'bn_eps': 0.001}, 'decoder': {'_target_': 'text_recognizer.networks.transformer.Decoder', 'dim': 128, 'depth': 3, 'num_heads': 4, 'attn_fn': 'text_recognizer.networks.transformer.attention.Attention', 'attn_kwargs': {'dim_head': 32, 'dropout_rate': 0.2}, 'norm_fn': 'text_recognizer.networks.transformer.norm.ScaleNorm', 'ff_fn': 'text_recognizer.networks.transformer.mlp.FeedForward', 'ff_kwargs': {'dim_out': None, 'expansion_factor': 4, 'glu': True, 'dropout_rate': 0.2}, 'cross_attend': True, 'pre_norm': True, 'rotary_emb': {'_target_': 'text_recognizer.networks.transformer.positional_encodings.rotary_embedding.RotaryEmbedding', 'dim': 32}}}, 'model': {'_target_': 'text_recognizer.models.transformer.TransformerLitModel', 'max_output_len': 89, 'start_token': '<s>', 'end_token': '<e>', 'pad_token': '<p>'}, 'trainer': {'_target_': 'pytorch_lightning.Trainer', 'stochastic_weight_avg': True, 'auto_scale_batch_size': 'binsearch', 'auto_lr_find': False, 'gradient_clip_val': 0.5, 'fast_dev_run': False, 'gpus': 1, 'precision': 16, 'max_epochs': 420, 'terminate_on_nan': True, 'weights_summary': None, 'limit_train_batches': 1.0, 'limit_val_batches': 1.0, 'limit_test_batches': 1.0, 'resume_from_checkpoint': None, 'accumulate_grad_batches': 8, 'overfit_batches': 0}, 'summary': [[1, 1, 56, 1024], [1, 89]]}"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"cfg"
]
},
{
"cell_type": "code",
- "execution_count": 9,
+ "execution_count": null,
"id": "aaeab329-aeb0-4a1b-aa35-5a2aab81b1d0",
"metadata": {},
"outputs": [],
@@ -92,7 +81,7 @@
},
{
"cell_type": "code",
- "execution_count": 10,
+ "execution_count": null,
"id": "618b997c-e6a6-4487-b70c-9d260cb556d3",
"metadata": {},
"outputs": [],
@@ -102,233 +91,10 @@
},
{
"cell_type": "code",
- "execution_count": 16,
+ "execution_count": null,
"id": "25759b7b-8deb-4163-b75d-a1357c9fe88f",
"metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "tensor([[[[-0.1375, 0.1336, -0.9556, ..., 1.3900, -0.1311, 0.0424],\n",
- " [ 0.3474, 0.3920, -0.7718, ..., 1.3900, -0.1311, 0.0424],\n",
- " [ 0.5129, 0.5296, -0.5113, ..., 1.3900, -0.1311, 0.0424],\n",
- " ...,\n",
- " [-0.4101, -0.5381, 0.8383, ..., 1.3900, -0.1311, 0.0424],\n",
- " [-0.4902, -0.4792, 0.9942, ..., 1.3900, -0.1311, 0.0424],\n",
- " [-0.1197, -0.2727, 1.0515, ..., 1.3900, -0.1311, 0.0424]],\n",
- "\n",
- " [[-0.5087, -0.9362, 0.4579, ..., -0.9192, 0.5128, -0.4577],\n",
- " [-0.9311, -0.7565, 0.2538, ..., -0.9192, 0.5128, -0.4577],\n",
- " [-0.4974, -0.3438, 0.0245, ..., -0.9192, 0.5128, -0.4577],\n",
- " ...,\n",
- " [ 0.9153, 0.2433, -0.7304, ..., -0.9192, 0.5128, -0.4577],\n",
- " [ 0.3510, -0.2775, -0.7340, ..., -0.9192, 0.5128, -0.4577],\n",
- " [-0.5360, -0.7128, -0.6648, ..., -0.9192, 0.5128, -0.4577]],\n",
- "\n",
- " [[-0.3042, 0.4230, -0.1477, ..., 0.5753, 0.2627, -0.9272],\n",
- " [-0.6520, 0.1020, -0.3723, ..., 0.5753, 0.2627, -0.9272],\n",
- " [-0.4004, -0.2504, -0.5599, ..., 0.5753, 0.2627, -0.9272],\n",
- " ...,\n",
- " [ 0.6519, 0.3150, -0.5875, ..., 0.5753, 0.2627, -0.9272],\n",
- " [ 0.3030, 0.5633, -0.4083, ..., 0.5753, 0.2627, -0.9272],\n",
- " [-0.3245, 0.6381, -0.1886, ..., 0.5753, 0.2627, -0.9272]],\n",
- "\n",
- " [[-0.3549, 0.1346, 0.7488, ..., -0.0206, 0.3056, -0.3023],\n",
- " [-0.7494, 0.0327, 0.8030, ..., -0.0206, 0.3056, -0.3023],\n",
- " [-0.4549, -0.0792, 0.7776, ..., -0.0206, 0.3056, -0.3023],\n",
- " ...,\n",
- " [ 0.7482, 0.0998, -0.0949, ..., -0.0206, 0.3056, -0.3023],\n",
- " [ 0.3424, 0.1788, -0.3386, ..., -0.0206, 0.3056, -0.3023],\n",
- " [-0.3782, 0.2027, -0.5487, ..., -0.0206, 0.3056, -0.3023]]]])\n",
- "tensor([[[[-0.8664, -0.5805, 0.3959, ..., -0.3215, -1.1297, -1.0401],\n",
- " [-0.8629, -0.5843, 0.3928, ..., -0.3136, -1.1305, -1.0422],\n",
- " [-0.8580, -0.5812, 0.3899, ..., -0.3078, -1.1307, -1.0447],\n",
- " ...,\n",
- " [-0.8821, -0.5753, 0.3974, ..., -0.3362, -1.1056, -1.0572],\n",
- " [-0.8827, -0.5753, 0.3972, ..., -0.3362, -1.1055, -1.0570],\n",
- " [-0.8833, -0.5757, 0.3969, ..., -0.3359, -1.1051, -1.0570]],\n",
- "\n",
- " [[ 0.4034, -0.0638, 0.4687, ..., 0.2240, -0.5357, -0.2091],\n",
- " [ 0.3946, -0.0698, 0.4697, ..., 0.2241, -0.5370, -0.2128],\n",
- " [ 0.3872, -0.0761, 0.4697, ..., 0.2269, -0.5366, -0.2166],\n",
- " ...,\n",
- " [ 0.4107, -0.0766, 0.4126, ..., 0.2256, -0.5160, -0.2184],\n",
- " [ 0.4108, -0.0764, 0.4126, ..., 0.2257, -0.5160, -0.2182],\n",
- " [ 0.4109, -0.0761, 0.4126, ..., 0.2256, -0.5160, -0.2181]],\n",
- "\n",
- " [[ 1.1814, -0.3463, 0.7075, ..., 0.3644, 0.1285, 0.1048],\n",
- " [ 1.1817, -0.3528, 0.7070, ..., 0.3619, 0.1285, 0.1008],\n",
- " [ 1.1820, -0.3611, 0.7079, ..., 0.3566, 0.1281, 0.0932],\n",
- " ...,\n",
- " [ 1.1625, -0.3667, 0.7248, ..., 0.3525, 0.1328, 0.0854],\n",
- " [ 1.1626, -0.3665, 0.7251, ..., 0.3525, 0.1329, 0.0860],\n",
- " [ 1.1626, -0.3658, 0.7254, ..., 0.3526, 0.1329, 0.0866]],\n",
- "\n",
- " [[-0.3258, -0.0492, -0.1981, ..., -0.0680, -0.1737, 0.8813],\n",
- " [-0.3236, -0.0472, -0.1983, ..., -0.0716, -0.1781, 0.8754],\n",
- " [-0.3207, -0.0424, -0.2006, ..., -0.0735, -0.1775, 0.8708],\n",
- " ...,\n",
- " [-0.3556, -0.0327, -0.2103, ..., -0.0663, -0.1694, 0.9025],\n",
- " [-0.3566, -0.0330, -0.2102, ..., -0.0662, -0.1690, 0.9030],\n",
- " [-0.3577, -0.0334, -0.2098, ..., -0.0664, -0.1692, 0.9035]]]])\n",
- "tensor([[[[ 0.4187, -0.4013, 0.1790, ..., -0.7965, -0.4432, -0.4109],\n",
- " [ 0.4706, -0.2460, 0.2507, ..., -0.7946, -0.4455, -0.4032],\n",
- " [ 0.0920, -0.0103, 0.3006, ..., -0.7921, -0.4488, -0.3960],\n",
- " ...,\n",
- " [-0.4408, -0.0365, 0.1528, ..., -0.8183, -0.4272, -0.3973],\n",
- " [ 0.0095, -0.2703, 0.0542, ..., -0.8183, -0.4264, -0.3973],\n",
- " [ 0.4515, -0.4208, -0.0497, ..., -0.8187, -0.4256, -0.3972]],\n",
- "\n",
- " [[-0.2298, 0.5305, -1.6682, ..., -0.9803, 0.4137, -0.7174],\n",
- " [-1.2319, -0.1384, -1.7214, ..., -0.9782, 0.4137, -0.7245],\n",
- " [-1.0939, -0.7782, -1.6090, ..., -0.9785, 0.4124, -0.7308],\n",
- " ...,\n",
- " [ 1.3048, 0.8945, 0.3699, ..., -0.9897, 0.3753, -0.7444],\n",
- " [ 0.9834, 1.2249, 0.8801, ..., -0.9897, 0.3753, -0.7445],\n",
- " [-0.2414, 1.1782, 1.3028, ..., -0.9894, 0.3752, -0.7445]],\n",
- "\n",
- " [[-0.1399, 0.4728, 0.7512, ..., -0.4790, 0.0188, 1.0488],\n",
- " [ 0.3772, 0.0967, 0.6151, ..., -0.4738, 0.0168, 1.0455],\n",
- " [ 0.5529, -0.3050, 0.4153, ..., -0.4704, 0.0122, 1.0429],\n",
- " ...,\n",
- " [-0.4351, 0.3780, -0.6546, ..., -0.5326, 0.0208, 1.0425],\n",
- " [-0.5327, 0.6673, -0.7752, ..., -0.5330, 0.0208, 1.0421],\n",
- " [-0.1395, 0.7507, -0.8185, ..., -0.5332, 0.0209, 1.0415]],\n",
- "\n",
- " [[ 0.4693, -0.3338, 0.6982, ..., 0.4498, 1.1733, -0.4666],\n",
- " [ 0.9968, -0.1894, 0.7689, ..., 0.4459, 1.1737, -0.4665],\n",
- " [ 0.6078, 0.0178, 0.7619, ..., 0.4457, 1.1731, -0.4681],\n",
- " ...,\n",
- " [-0.9903, -0.0882, -0.0023, ..., 0.4639, 1.1550, -0.4838],\n",
- " [-0.4527, -0.2828, -0.2390, ..., 0.4640, 1.1545, -0.4840],\n",
- " [ 0.5021, -0.3899, -0.4519, ..., 0.4636, 1.1543, -0.4840]]]])\n",
- "tensor([[[[-0.5996, -0.8818, 1.0916, ..., 1.3350, -0.6233, -0.4210],\n",
- " [-0.5956, -0.8787, 1.0899, ..., 1.3421, -0.6181, -0.4139],\n",
- " [-0.5935, -0.8731, 1.0919, ..., 1.3472, -0.6063, -0.4136],\n",
- " ...,\n",
- " [-0.6075, -0.8807, 1.0738, ..., 1.3215, -0.5991, -0.4371],\n",
- " [-0.6073, -0.8811, 1.0738, ..., 1.3200, -0.5991, -0.4377],\n",
- " [-0.6068, -0.8812, 1.0731, ..., 1.3187, -0.5991, -0.4378]],\n",
- "\n",
- " [[-0.5803, -0.4059, -0.0754, ..., 0.6803, -0.3709, -0.1859],\n",
- " [-0.5851, -0.3976, -0.0814, ..., 0.6705, -0.3691, -0.1860],\n",
- " [-0.5859, -0.3901, -0.0849, ..., 0.6577, -0.3676, -0.1925],\n",
- " ...,\n",
- " [-0.5888, -0.4312, -0.0346, ..., 0.6210, -0.3170, -0.2286],\n",
- " [-0.5888, -0.4318, -0.0342, ..., 0.6211, -0.3165, -0.2287],\n",
- " [-0.5889, -0.4323, -0.0338, ..., 0.6214, -0.3162, -0.2285]],\n",
- "\n",
- " [[-1.3079, 0.6818, 0.7237, ..., -0.2953, 1.0035, 0.0144],\n",
- " [-1.3053, 0.7041, 0.7281, ..., -0.3037, 1.0063, 0.0130],\n",
- " [-1.3047, 0.7197, 0.7334, ..., -0.3116, 1.0118, 0.0097],\n",
- " ...,\n",
- " [-1.3425, 0.6961, 0.7392, ..., -0.2939, 1.0508, 0.0380],\n",
- " [-1.3430, 0.6959, 0.7389, ..., -0.2938, 1.0508, 0.0380],\n",
- " [-1.3436, 0.6961, 0.7390, ..., -0.2936, 1.0504, 0.0379]],\n",
- "\n",
- " [[-0.1651, -0.4341, -0.5129, ..., -0.2994, 0.5808, 0.0624],\n",
- " [-0.1675, -0.4243, -0.5142, ..., -0.3024, 0.5866, 0.0570],\n",
- " [-0.1647, -0.4233, -0.5158, ..., -0.2985, 0.5921, 0.0548],\n",
- " ...,\n",
- " [-0.1671, -0.4497, -0.4654, ..., -0.3217, 0.5996, 0.0394],\n",
- " [-0.1665, -0.4494, -0.4648, ..., -0.3218, 0.5996, 0.0393],\n",
- " [-0.1666, -0.4481, -0.4641, ..., -0.3226, 0.5996, 0.0394]]]])\n",
- "tensor([[[[ 0.7548, -0.7519, 0.9516, ..., -0.4464, -0.3636, -0.2868],\n",
- " [-0.1622, -0.7571, 0.5601, ..., -0.4323, -0.3665, -0.2852],\n",
- " [-0.9178, -0.5404, 0.1087, ..., -0.4212, -0.3705, -0.2827],\n",
- " ...,\n",
- " [ 0.3565, 0.4814, -1.4442, ..., -0.4623, -0.3604, -0.2805],\n",
- " [ 0.9851, 0.0684, -1.4689, ..., -0.4634, -0.3601, -0.2801],\n",
- " [ 0.7091, -0.3652, -1.3481, ..., -0.4641, -0.3598, -0.2800]],\n",
- "\n",
- " [[ 0.0558, -0.0961, -0.3113, ..., 0.3085, 0.3596, 0.7701],\n",
- " [ 0.0225, -0.3077, -0.4186, ..., 0.3083, 0.3633, 0.7737],\n",
- " [-0.0396, -0.4271, -0.4848, ..., 0.3086, 0.3699, 0.7804],\n",
- " ...,\n",
- " [ 0.0266, 0.4086, -0.2584, ..., 0.3113, 0.3957, 0.7807],\n",
- " [ 0.0362, 0.4019, -0.1022, ..., 0.3113, 0.3957, 0.7810],\n",
- " [ 0.0117, 0.2716, 0.0645, ..., 0.3113, 0.3955, 0.7811]],\n",
- "\n",
- " [[-0.0074, -0.1536, 0.3226, ..., 0.6790, -0.0736, 0.4578],\n",
- " [ 0.0267, -0.2415, 0.4238, ..., 0.6830, -0.0770, 0.4544],\n",
- " [ 0.0347, -0.2485, 0.4886, ..., 0.6840, -0.0802, 0.4572],\n",
- " ...,\n",
- " [-0.0258, 0.2343, 0.1960, ..., 0.7261, -0.0211, 0.5036],\n",
- " [-0.0052, 0.1651, 0.0384, ..., 0.7260, -0.0208, 0.5045],\n",
- " [ 0.0203, 0.0448, -0.1227, ..., 0.7264, -0.0205, 0.5050]],\n",
- "\n",
- " [[ 0.3598, 0.9917, 0.2596, ..., -0.6748, 0.2234, 0.4948],\n",
- " [-0.3283, 0.9644, 0.2100, ..., -0.6691, 0.2088, 0.4965],\n",
- " [-0.7103, 0.6471, 0.1438, ..., -0.6655, 0.1964, 0.4979],\n",
- " ...,\n",
- " [ 0.4529, -0.5427, -0.1871, ..., -0.7174, 0.2235, 0.4909],\n",
- " [ 0.7765, 0.0142, -0.2525, ..., -0.7171, 0.2236, 0.4911],\n",
- " [ 0.3870, 0.5668, -0.2926, ..., -0.7172, 0.2239, 0.4913]]]])\n",
- "tensor([[[[-0.4846, -0.0448, 0.7438, ..., -0.2732, 0.5203, -0.3579],\n",
- " [-0.4858, -0.0525, 0.7421, ..., -0.2585, 0.5151, -0.3524],\n",
- " [-0.4920, -0.0616, 0.7417, ..., -0.2429, 0.5112, -0.3484],\n",
- " ...,\n",
- " [-0.5254, -0.1285, 0.7435, ..., -0.2306, 0.5152, -0.3257],\n",
- " [-0.5255, -0.1289, 0.7432, ..., -0.2314, 0.5157, -0.3253],\n",
- " [-0.5254, -0.1291, 0.7429, ..., -0.2321, 0.5158, -0.3250]],\n",
- "\n",
- " [[-0.4301, 0.8671, 0.5170, ..., -1.1632, -0.3734, -1.1874],\n",
- " [-0.4446, 0.8762, 0.5196, ..., -1.1528, -0.3724, -1.1955],\n",
- " [-0.4582, 0.8844, 0.5180, ..., -1.1396, -0.3729, -1.1942],\n",
- " ...,\n",
- " [-0.4770, 0.8541, 0.5325, ..., -1.1392, -0.3525, -1.1223],\n",
- " [-0.4766, 0.8535, 0.5316, ..., -1.1390, -0.3523, -1.1217],\n",
- " [-0.4760, 0.8530, 0.5309, ..., -1.1390, -0.3519, -1.1214]],\n",
- "\n",
- " [[-0.6542, -0.1481, 0.0284, ..., 0.1807, 0.0668, -0.1159],\n",
- " [-0.6490, -0.1293, 0.0320, ..., 0.1974, 0.0622, -0.1280],\n",
- " [-0.6482, -0.1144, 0.0331, ..., 0.2133, 0.0589, -0.1325],\n",
- " ...,\n",
- " [-0.6662, -0.0985, 0.0627, ..., 0.1447, 0.0864, -0.0778],\n",
- " [-0.6664, -0.0994, 0.0632, ..., 0.1434, 0.0864, -0.0774],\n",
- " [-0.6662, -0.0997, 0.0639, ..., 0.1421, 0.0864, -0.0775]],\n",
- "\n",
- " [[-0.1255, 1.4525, -0.4119, ..., 0.6365, 0.6966, 0.2816],\n",
- " [-0.1274, 1.4516, -0.4094, ..., 0.6396, 0.6962, 0.2907],\n",
- " [-0.1325, 1.4417, -0.4089, ..., 0.6359, 0.6969, 0.3028],\n",
- " ...,\n",
- " [-0.1968, 1.3976, -0.3989, ..., 0.6016, 0.7498, 0.3160],\n",
- " [-0.1972, 1.3966, -0.3989, ..., 0.6020, 0.7502, 0.3163],\n",
- " [-0.1977, 1.3958, -0.3984, ..., 0.6025, 0.7504, 0.3163]]]])\n"
- ]
- },
- {
- "data": {
- "text/plain": [
- "====================================================================================================\n",
- "Layer (type:depth-idx) Output Shape Param #\n",
- "====================================================================================================\n",
- "ConvTransformer -- --\n",
- "├─EfficientNet: 1-1 [1, 1280, 1, 32] 7,142,272\n",
- "├─Sequential: 1-2 [1, 128, 32] 163,968\n",
- "├─Embedding: 1-3 [1, 89, 128] 7,424\n",
- "├─PositionalEncoding: 1-4 [1, 89, 128] --\n",
- "├─Decoder: 1-5 [1, 89, 128] 13,176,969\n",
- "├─Linear: 1-6 [1, 89, 58] 7,482\n",
- "====================================================================================================\n",
- "Total params: 20,498,115\n",
- "Trainable params: 20,498,115\n",
- "Non-trainable params: 0\n",
- "Total mult-adds (M): 714.86\n",
- "====================================================================================================\n",
- "Input size (MB): 0.23\n",
- "Forward/backward pass size (MB): 184.29\n",
- "Params size (MB): 81.99\n",
- "Estimated Total Size (MB): 266.51\n",
- "===================================================================================================="
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
+ "outputs": [],
"source": [
"summary(net, list(map(lambda x: list(x), cfg.summary)), device=\"cpu\", depth=1)"
]
@@ -336,7 +102,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "bf9d8d67-d7d2-4cf7-b166-377a79d5fd70",
+ "id": "4b1fe971-2a08-4010-855a-7971067cc559",
"metadata": {},
"outputs": [],
"source": []