summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:09:19 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-05 00:09:19 +0200
commit12abf17cd7c31ae4599be366505a4423fbba4044 (patch)
tree996e5d549ebbb7d22f5acfdcd321bddec77f98d1
parent16e2e420e077253c3b2bc414283281fea557717d (diff)
Update perceiver conf
-rw-r--r--training/conf/experiment/conv_perceiver_lines.yaml76
-rw-r--r--training/conf/network/conv_perceiver.yaml23
-rw-r--r--training/conf/network/conv_transformer.yaml2
3 files changed, 92 insertions, 9 deletions
diff --git a/training/conf/experiment/conv_perceiver_lines.yaml b/training/conf/experiment/conv_perceiver_lines.yaml
new file mode 100644
index 0000000..26fe232
--- /dev/null
+++ b/training/conf/experiment/conv_perceiver_lines.yaml
@@ -0,0 +1,76 @@
+# @package _global_
+
+defaults:
+ - override /criterion: cross_entropy
+ - override /callbacks: htr
+ - override /datamodule: iam_lines
+ - override /network: conv_perceiver
+ - override /model: lit_perceiver
+ - override /lr_scheduler: null
+ - override /optimizer: null
+
+tags: [lines, perceiver]
+epochs: &epochs 260
+ignore_index: &ignore_index 3
+num_classes: &num_classes 57
+max_output_len: &max_output_len 89
+summary: [[1, 1, 56, 1024]]
+
+logger:
+ wandb:
+ tags: ${tags}
+
+criterion:
+ ignore_index: *ignore_index
+ # label_smoothing: 0.1
+
+callbacks:
+ stochastic_weight_averaging:
+ _target_: pytorch_lightning.callbacks.StochasticWeightAveraging
+ swa_epoch_start: 0.75
+ swa_lrs: 1.0e-5
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
+
+optimizer:
+ _target_: adan_pytorch.Adan
+ lr: 1.0e-4
+ betas: [0.02, 0.08, 0.01]
+ weight_decay: 0.02
+
+lr_scheduler:
+ _target_: torch.optim.lr_scheduler.ReduceLROnPlateau
+ mode: min
+ factor: 0.8
+ patience: 10
+ threshold: 1.0e-4
+ threshold_mode: rel
+ cooldown: 0
+ min_lr: 1.0e-5
+ eps: 1.0e-8
+ verbose: false
+ interval: epoch
+ monitor: val/cer
+
+datamodule:
+ batch_size: 8
+ train_fraction: 0.95
+
+network:
+ input_dims: [1, 1, 56, 1024]
+ num_classes: *num_classes
+ pad_index: *ignore_index
+ encoder:
+ depth: 5
+ decoder:
+ depth: 6
+
+model:
+ max_output_len: *max_output_len
+
+trainer:
+ gradient_clip_val: 1.0
+ stochastic_weight_avg: true
+ max_epochs: *epochs
+ accumulate_grad_batches: 1
diff --git a/training/conf/network/conv_perceiver.yaml b/training/conf/network/conv_perceiver.yaml
index e6906fa..2e12db9 100644
--- a/training/conf/network/conv_perceiver.yaml
+++ b/training/conf/network/conv_perceiver.yaml
@@ -1,9 +1,10 @@
_target_: text_recognizer.networks.ConvPerceiver
input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 144
+hidden_dim: &hidden_dim 128
num_classes: &num_classes 58
-queries_dim: &queries_dim 16
-max_length: 89
+max_length: &max_length 89
+num_queries: *max_length
+queries_dim: &queries_dim 64
pad_index: 3
encoder:
_target_: text_recognizer.networks.EfficientNet
@@ -15,16 +16,22 @@ encoder:
out_channels: *hidden_dim
decoder:
_target_: text_recognizer.networks.perceiver.PerceiverIO
- dim: *hidden_dim
+ dim: 192
cross_heads: 1
cross_head_dim: 64
num_latents: 256
latent_dim: 512
latent_heads: 8
depth: 6
- queries_dim: *queries_dim
+ queries_dim: 128
logits_dim: *num_classes
pixel_embedding:
- _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding
- dim: *hidden_dim
- shape: [3, 64]
+ _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage
+ dim: 64
+ axial_shape: [3, 64]
+ axial_dims: [32, 32]
+query_pos_emb:
+ _target_: text_recognizer.networks.transformer.embeddings.absolute.AbsolutePositionalEmbedding
+ dim: 64
+ max_seq_len: *max_length
+ l2norm_embed: true
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index c4cf13e..f618ba1 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -1,6 +1,6 @@
_target_: text_recognizer.networks.ConvTransformer
input_dims: [1, 1, 576, 640]
-hidden_dim: &hidden_dim 96
+hidden_dim: &hidden_dim 144
num_classes: 58
pad_index: 3
encoder: