summaryrefslogtreecommitdiff
path: root/training
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-25 23:32:50 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-25 23:32:50 +0200
commit9426cc794d8c28a65bbbf5ae5466a0a343078558 (patch)
tree44e31b0a7c58597d603ac29a693462aae4b6e9b0 /training
parent4e60c836fb710baceba570c28c06437db3ad5c9b (diff)
Efficient net and non working transformer model.
Diffstat (limited to 'training')
-rw-r--r--training/configs/cnn_transformer.yaml (renamed from training/configs/image_transformer.yaml)41
-rw-r--r--training/configs/vqvae.yaml10
2 files changed, 26 insertions, 25 deletions
diff --git a/training/configs/image_transformer.yaml b/training/configs/cnn_transformer.yaml
index e6637f2..a4f16df 100644
--- a/training/configs/image_transformer.yaml
+++ b/training/configs/cnn_transformer.yaml
@@ -2,12 +2,13 @@ seed: 4711
network:
desc: Configuration of the PyTorch neural network.
- type: ImageTransformer
+ type: CNNTransformer
args:
encoder:
- type: null
+ type: EfficientNet
args: null
num_decoder_layers: 4
+ vocab_size: 84
hidden_dim: 256
num_heads: 4
expansion_dim: 1024
@@ -26,7 +27,7 @@ model:
weight_decay: 0
eps: 1.0e-6
lr_scheduler:
- type: OneCycle
+ type: OneCycleLR
args:
interval: &interval step
max_lr: 1.0e-3
@@ -36,7 +37,7 @@ model:
criterion:
type: CrossEntropyLoss
args:
- weight: None
+ weight: null
ignore_index: -100
reduction: mean
monitor: val_loss
@@ -46,7 +47,7 @@ data:
desc: Configuration of the training/test data.
type: IAMExtendedParagraphs
args:
- batch_size: 16
+ batch_size: 8
num_workers: 12
train_fraction: 0.8
augment: true
@@ -57,33 +58,33 @@ callbacks:
monitor: val_loss
mode: min
save_last: true
- - type: StochasticWeightAveraging
- args:
- swa_epoch_start: 0.8
- swa_lrs: 0.05
- annealing_epochs: 10
- annealing_strategy: cos
- device: null
+ # - type: StochasticWeightAveraging
+ # args:
+ # swa_epoch_start: 0.8
+ # swa_lrs: 0.05
+ # annealing_epochs: 10
+ # annealing_strategy: cos
+ # device: null
- type: LearningRateMonitor
args:
logging_interval: *interval
- - type: EarlyStopping
- args:
- monitor: val_loss
- mode: min
- patience: 10
+ # - type: EarlyStopping
+ # args:
+ # monitor: val_loss
+ # mode: min
+ # patience: 10
trainer:
desc: Configuration of the PyTorch Lightning Trainer.
args:
- stochastic_weight_avg: true
+ stochastic_weight_avg: false
auto_scale_batch_size: binsearch
gradient_clip_val: 0
- fast_dev_run: false
+ fast_dev_run: true
gpus: 1
precision: 16
max_epochs: 512
terminate_on_nan: true
- weights_summary: true
+ weights_summary: top
load_checkpoint: null
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
index a7acb3a..13d7c97 100644
--- a/training/configs/vqvae.yaml
+++ b/training/configs/vqvae.yaml
@@ -5,12 +5,12 @@ network:
type: VQVAE
args:
in_channels: 1
- channels: [32, 64, 64]
- kernel_sizes: [4, 4, 4]
- strides: [2, 2, 2]
+ channels: [32, 64, 64, 96, 96]
+ kernel_sizes: [4, 4, 4, 4, 4]
+ strides: [2, 2, 2, 2, 2]
num_residual_layers: 2
- embedding_dim: 128
- num_embeddings: 512
+ embedding_dim: 512
+ num_embeddings: 1024
upsampling: null
beta: 0.25
activation: leaky_relu