summaryrefslogtreecommitdiff
path: root/training/conf/experiment/vq_htr_char.yaml
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-18 17:43:23 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-18 17:43:23 +0200
commit9ef2857c2d24d9c0a8fba3c5db58c7303124c79b (patch)
treedc7eb4a179b8cd706e39b650dd3d215bb667db85 /training/conf/experiment/vq_htr_char.yaml
parent0b8924f37fbab57a3d6f59421e9cd16421c9af4b (diff)
Update experiment configs
Diffstat (limited to 'training/conf/experiment/vq_htr_char.yaml')
-rw-r--r--training/conf/experiment/vq_htr_char.yaml74
1 files changed, 74 insertions, 0 deletions
diff --git a/training/conf/experiment/vq_htr_char.yaml b/training/conf/experiment/vq_htr_char.yaml
new file mode 100644
index 0000000..b34dd11
--- /dev/null
+++ b/training/conf/experiment/vq_htr_char.yaml
@@ -0,0 +1,74 @@
+# @package _global_
+
+defaults:
+ - override /mapping: null
+ - override /network: null
+ - override /model: null
+
+mapping:
+ _target_: text_recognizer.data.emnist_mapping.EmnistMapping
+ extra_symbols: [ "\n" ]
+
+datamodule:
+ word_pieces: false
+ batch_size: 8
+
+criterion:
+ ignore_index: 3
+
+network:
+ _target_: text_recognizer.networks.vq_transformer.VqTransformer
+ input_dims: [1, 576, 640]
+ encoder_dim: 64
+ hidden_dim: 64
+ dropout_rate: 0.1
+ num_classes: 58
+ pad_index: 3
+ no_grad: false
+ encoder:
+ _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
+ hidden_dim: 128
+ embedding_dim: 64
+ num_embeddings: 1024
+ decay: 0.99
+ encoder:
+ _target_: text_recognizer.networks.vqvae.encoder.Encoder
+ in_channels: 1
+ hidden_dim: 64
+ channels_multipliers: [1, 1, 2, 2]
+ dropout_rate: 0.0
+ decoder:
+ _target_: text_recognizer.networks.vqvae.decoder.Decoder
+ out_channels: 1
+ hidden_dim: 64
+ channels_multipliers: [2, 2, 1, 1]
+ dropout_rate: 0.0
+ decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ dim: 64
+ depth: 2
+ num_heads: 4
+ attn_fn: text_recognizer.networks.transformer.attention.Attention
+ attn_kwargs:
+ dim_head: 32
+ dropout_rate: 0.2
+ norm_fn: torch.nn.LayerNorm
+ ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+ ff_kwargs:
+ dim_out: null
+ expansion_factor: 4
+ glu: true
+ dropout_rate: 0.2
+ cross_attend: true
+ pre_norm: true
+ rotary_emb: null
+
+ # pretrained_encoder_path: "training/logs/runs/2021-09-13/08-35-57/checkpoints/epoch=98.ckpt"
+
+model:
+ _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
+ start_token: <s>
+ end_token: <e>
+ pad_token: <p>
+ max_output_len: 682
+ # max_output_len: 451