summaryrefslogtreecommitdiff
path: root/training/conf/experiment/vq_htr_char.yaml
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:05:51 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:05:51 +0200
commit53dce464914b2ce3817a7cec9441f7cfa5048248 (patch)
tree18688185b9ee158a7f80aef9d17bb59221ca07ee /training/conf/experiment/vq_htr_char.yaml
parent6dd33e5a087159dcbabb845f167279778b2a8ea5 (diff)
Update to experiments configs
Diffstat (limited to 'training/conf/experiment/vq_htr_char.yaml')
-rw-r--r--training/conf/experiment/vq_htr_char.yaml74
1 files changed, 0 insertions, 74 deletions
diff --git a/training/conf/experiment/vq_htr_char.yaml b/training/conf/experiment/vq_htr_char.yaml
deleted file mode 100644
index b34dd11..0000000
--- a/training/conf/experiment/vq_htr_char.yaml
+++ /dev/null
@@ -1,74 +0,0 @@
-# @package _global_
-
-defaults:
- - override /mapping: null
- - override /network: null
- - override /model: null
-
-mapping:
- _target_: text_recognizer.data.emnist_mapping.EmnistMapping
- extra_symbols: [ "\n" ]
-
-datamodule:
- word_pieces: false
- batch_size: 8
-
-criterion:
- ignore_index: 3
-
-network:
- _target_: text_recognizer.networks.vq_transformer.VqTransformer
- input_dims: [1, 576, 640]
- encoder_dim: 64
- hidden_dim: 64
- dropout_rate: 0.1
- num_classes: 58
- pad_index: 3
- no_grad: false
- encoder:
- _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
- hidden_dim: 128
- embedding_dim: 64
- num_embeddings: 1024
- decay: 0.99
- encoder:
- _target_: text_recognizer.networks.vqvae.encoder.Encoder
- in_channels: 1
- hidden_dim: 64
- channels_multipliers: [1, 1, 2, 2]
- dropout_rate: 0.0
- decoder:
- _target_: text_recognizer.networks.vqvae.decoder.Decoder
- out_channels: 1
- hidden_dim: 64
- channels_multipliers: [2, 2, 1, 1]
- dropout_rate: 0.0
- decoder:
- _target_: text_recognizer.networks.transformer.Decoder
- dim: 64
- depth: 2
- num_heads: 4
- attn_fn: text_recognizer.networks.transformer.attention.Attention
- attn_kwargs:
- dim_head: 32
- dropout_rate: 0.2
- norm_fn: torch.nn.LayerNorm
- ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
- ff_kwargs:
- dim_out: null
- expansion_factor: 4
- glu: true
- dropout_rate: 0.2
- cross_attend: true
- pre_norm: true
- rotary_emb: null
-
- # pretrained_encoder_path: "training/logs/runs/2021-09-13/08-35-57/checkpoints/epoch=98.ckpt"
-
-model:
- _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
- start_token: <s>
- end_token: <e>
- pad_token: <p>
- max_output_len: 682
- # max_output_len: 451