summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/conf/config.yaml10
-rw-r--r--training/conf/datamodule/iam_extended_paragraphs.yaml1
-rw-r--r--training/conf/experiment/htr_char.yaml17
-rw-r--r--training/conf/experiment/vq_htr_char.yaml74
-rw-r--r--training/conf/experiment/vqgan.yaml36
-rw-r--r--training/conf/experiment/vqvae.yaml38
-rw-r--r--training/conf/model/lit_transformer.yaml2
-rw-r--r--training/conf/model/lit_vqvae.yaml1
-rw-r--r--training/conf/network/conv_transformer.yaml1
-rw-r--r--training/conf/network/decoder/vae_decoder.yaml5
-rw-r--r--training/conf/network/encoder/vae_encoder.yaml5
-rw-r--r--training/conf/network/vqvae.yaml6
-rw-r--r--training/conf/optimizers/madgrad.yaml2
13 files changed, 147 insertions, 51 deletions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index 5897d87..9ed366f 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -7,8 +7,8 @@ defaults:
- hydra: default
- logger: wandb
- lr_schedulers:
- - one_cycle
- - mapping: word_piece
+ - cosine_annealing
+ - mapping: characters # word_piece
- model: lit_transformer
- network: conv_transformer
- optimizers:
@@ -21,6 +21,12 @@ train: true
test: true
logging: INFO
+# datamodule:
+# word_pieces: false
+
+# model:
+# max_output_len: 682
+
# path to original working directory
# hydra hijacks working directory by changing it to the current log directory,
# so it's useful to have this path as a special variable
diff --git a/training/conf/datamodule/iam_extended_paragraphs.yaml b/training/conf/datamodule/iam_extended_paragraphs.yaml
index a2dd293..a0ffe56 100644
--- a/training/conf/datamodule/iam_extended_paragraphs.yaml
+++ b/training/conf/datamodule/iam_extended_paragraphs.yaml
@@ -5,3 +5,4 @@ train_fraction: 0.8
augment: true
pin_memory: false
word_pieces: true
+resize: null
diff --git a/training/conf/experiment/htr_char.yaml b/training/conf/experiment/htr_char.yaml
deleted file mode 100644
index e51a116..0000000
--- a/training/conf/experiment/htr_char.yaml
+++ /dev/null
@@ -1,17 +0,0 @@
-# @package _global_
-
-defaults:
- - override /mapping: characters
-
-datamodule:
- word_pieces: false
-
-criterion:
- ignore_index: 3
-
-network:
- num_classes: 58
- pad_index: 3
-
-model:
- max_output_len: 682
diff --git a/training/conf/experiment/vq_htr_char.yaml b/training/conf/experiment/vq_htr_char.yaml
new file mode 100644
index 0000000..b34dd11
--- /dev/null
+++ b/training/conf/experiment/vq_htr_char.yaml
@@ -0,0 +1,74 @@
+# @package _global_
+
+defaults:
+ - override /mapping: null
+ - override /network: null
+ - override /model: null
+
+mapping:
+ _target_: text_recognizer.data.emnist_mapping.EmnistMapping
+ extra_symbols: [ "\n" ]
+
+datamodule:
+ word_pieces: false
+ batch_size: 8
+
+criterion:
+ ignore_index: 3
+
+network:
+ _target_: text_recognizer.networks.vq_transformer.VqTransformer
+ input_dims: [1, 576, 640]
+ encoder_dim: 64
+ hidden_dim: 64
+ dropout_rate: 0.1
+ num_classes: 58
+ pad_index: 3
+ no_grad: false
+ encoder:
+ _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
+ hidden_dim: 128
+ embedding_dim: 64
+ num_embeddings: 1024
+ decay: 0.99
+ encoder:
+ _target_: text_recognizer.networks.vqvae.encoder.Encoder
+ in_channels: 1
+ hidden_dim: 64
+ channels_multipliers: [1, 1, 2, 2]
+ dropout_rate: 0.0
+ decoder:
+ _target_: text_recognizer.networks.vqvae.decoder.Decoder
+ out_channels: 1
+ hidden_dim: 64
+ channels_multipliers: [2, 2, 1, 1]
+ dropout_rate: 0.0
+ decoder:
+ _target_: text_recognizer.networks.transformer.Decoder
+ dim: 64
+ depth: 2
+ num_heads: 4
+ attn_fn: text_recognizer.networks.transformer.attention.Attention
+ attn_kwargs:
+ dim_head: 32
+ dropout_rate: 0.2
+ norm_fn: torch.nn.LayerNorm
+ ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+ ff_kwargs:
+ dim_out: null
+ expansion_factor: 4
+ glu: true
+ dropout_rate: 0.2
+ cross_attend: true
+ pre_norm: true
+ rotary_emb: null
+
+ # pretrained_encoder_path: "training/logs/runs/2021-09-13/08-35-57/checkpoints/epoch=98.ckpt"
+
+model:
+ _target_: text_recognizer.models.vq_transformer.VqTransformerLitModel
+ start_token: <s>
+ end_token: <e>
+ pad_token: <p>
+ max_output_len: 682
+ # max_output_len: 451
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 9224bc7..6c78deb 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -2,7 +2,7 @@
defaults:
- override /network: vqvae
- - override /criterion: vqgan_loss
+ - override /criterion: null
- override /model: lit_vqgan
- override /callbacks: wandb_vae
- override /optimizers: null
@@ -11,7 +11,7 @@ defaults:
criterion:
_target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
reconstruction_loss:
- _target_: torch.nn.L1Loss
+ _target_: torch.nn.MSELoss
reduction: mean
discriminator:
_target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
@@ -21,35 +21,41 @@ criterion:
vq_loss_weight: 0.25
discriminator_weight: 1.0
discriminator_factor: 1.0
- discriminator_iter_start: 2.0e4
+ discriminator_iter_start: 5e2
datamodule:
- batch_size: 6
+ batch_size: 8
+ resize: [288, 320]
-lr_schedulers: null
+lr_schedulers:
+ generator:
+ _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+ T_max: 128
+ eta_min: 4.5e-6
+ last_epoch: -1
-# lr_schedulers:
-# generator:
+ interval: epoch
+ monitor: val/loss
# _target_: torch.optim.lr_scheduler.OneCycleLR
# max_lr: 3.0e-4
# total_steps: null
# epochs: 100
-# steps_per_epoch: 3369
+# steps_per_epoch: 2496
# pct_start: 0.1
# anneal_strategy: cos
# cycle_momentum: true
# base_momentum: 0.85
# max_momentum: 0.95
-# div_factor: 1.0e3
+# div_factor: 25
# final_div_factor: 1.0e4
# three_phase: true
# last_epoch: -1
# verbose: false
-#
+
# # Non-class arguments
# interval: step
# monitor: val/loss
-#
+
# discriminator:
# _target_: torch.optim.lr_scheduler.CosineAnnealingLR
# T_max: 64
@@ -79,7 +85,7 @@ optimizers:
parameters: loss_fn.discriminator
trainer:
- max_epochs: 64
- # gradient_clip_val: 1.0e1
-
-summary: null
+ max_epochs: 128
+ limit_train_batches: 0.1
+ limit_val_batches: 0.1
+ # gradient_clip_val: 100
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index d3db471..d9fa2c4 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -2,26 +2,52 @@
defaults:
- override /network: vqvae
- - override /criterion: mae
+ - override /criterion: mse
- override /model: lit_vqvae
- override /callbacks: wandb_vae
- - override /lr_schedulers:
- - cosine_annealing
+ - override /optimizers: null
+ # - override /lr_schedulers:
+ # - cosine_annealing
+
+# lr_schedulers: null
+# network:
+# _target_: torch.optim.lr_scheduler.OneCycleLR
+# max_lr: 1.0e-2
+# total_steps: null
+# epochs: 100
+# steps_per_epoch: 200
+# pct_start: 0.1
+# anneal_strategy: cos
+# cycle_momentum: true
+# base_momentum: 0.85
+# max_momentum: 0.95
+# div_factor: 25
+# final_div_factor: 1.0e4
+# three_phase: true
+# last_epoch: -1
+# verbose: false
+
+# # Non-class arguments
+# interval: step
+# monitor: val/loss
optimizers:
network:
_target_: madgrad.MADGRAD
- lr: 3.0e-4
+ lr: 1.0e-4
momentum: 0.9
weight_decay: 0
- eps: 1.0e-6
+ eps: 1.0e-7
parameters: network
trainer:
- max_epochs: 256
+ max_epochs: 128
+ limit_train_batches: 0.01
+ limit_val_batches: 0.1
datamodule:
batch_size: 8
+ # resize: [288, 320]
summary: null
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index c190151..0ec3b8a 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,6 +1,4 @@
_target_: text_recognizer.models.transformer.TransformerLitModel
-interval: step
-monitor: val/loss
max_output_len: 451
start_token: <s>
end_token: <e>
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 632668b..6dc44d7 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,2 +1 @@
_target_: text_recognizer.models.vqvae.VQVAELitModel
-latent_loss_weight: 0.25
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index d3a3b0f..1d61129 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -5,6 +5,7 @@ defaults:
_target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
hidden_dim: 128
+encoder_dim: 1280
dropout_rate: 0.2
num_classes: 1006
pad_index: 1002
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index 60cdcf1..2053544 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,5 +1,6 @@
_target_: text_recognizer.networks.vqvae.decoder.Decoder
out_channels: 1
-hidden_dim: 64
-channels_multipliers: [8, 4, 2, 1]
+hidden_dim: 32
+channels_multipliers: [4, 2, 1]
dropout_rate: 0.0
+activation: leaky_relu
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index 73529fc..2ea3adf 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,6 @@
_target_: text_recognizer.networks.vqvae.encoder.Encoder
in_channels: 1
-hidden_dim: 64
-channels_multipliers: [1, 2, 4, 8]
+hidden_dim: 32
+channels_multipliers: [1, 2, 4]
dropout_rate: 0.0
+activation: leaky_relu
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index 70d27d7..d97e9b6 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -3,7 +3,7 @@ defaults:
- decoder: vae_decoder
_target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 512
-embedding_dim: 64
-num_embeddings: 4096
+hidden_dim: 128
+embedding_dim: 32
+num_embeddings: 1024
decay: 0.99
diff --git a/training/conf/optimizers/madgrad.yaml b/training/conf/optimizers/madgrad.yaml
index a6c059d..d97bf7e 100644
--- a/training/conf/optimizers/madgrad.yaml
+++ b/training/conf/optimizers/madgrad.yaml
@@ -1,6 +1,6 @@
madgrad:
_target_: madgrad.MADGRAD
- lr: 1.0e-3
+ lr: 3.0e-4
momentum: 0.9
weight_decay: 0
eps: 1.0e-6