author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-24 23:09:20 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-04-24 23:09:20 +0200
commit    4e60c836fb710baceba570c28c06437db3ad5c9b (patch)
tree      21caf6d1792bd83a47fb3d372ee7120211e83f18 /training
parent    1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (diff)
Implementing CoaT transformer, continue tomorrow...
Diffstat (limited to 'training')
-rw-r--r--  training/configs/vqvae.yaml  44
-rw-r--r--  training/run_experiment.py    2
2 files changed, 23 insertions(+), 23 deletions(-)
diff --git a/training/configs/vqvae.yaml b/training/configs/vqvae.yaml
index 90082f7..a7acb3a 100644
--- a/training/configs/vqvae.yaml
+++ b/training/configs/vqvae.yaml
@@ -5,16 +5,16 @@ network:
type: VQVAE
args:
in_channels: 1
- channels: [32, 64, 96]
+ channels: [32, 64, 64]
kernel_sizes: [4, 4, 4]
strides: [2, 2, 2]
num_residual_layers: 2
- embedding_dim: 64
- num_embeddings: 1024
+ embedding_dim: 128
+ num_embeddings: 512
upsampling: null
beta: 0.25
activation: leaky_relu
- dropout_rate: 0.1
+ dropout_rate: 0.2
model:
desc: Configuration of the PyTorch Lightning model.
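This hunk narrows the last encoder stage (channels 96 -> 64), doubles each codebook vector (embedding_dim 64 -> 128), halves the codebook (num_embeddings 1024 -> 512), and raises dropout to 0.2. For reference, a minimal sketch of the vector-quantization bottleneck that num_embeddings, embedding_dim, and beta parameterize, in the style of the original VQ-VAE paper; illustrative only, not this repo's actual VQVAE class:

    # Illustrative quantizer; names and shapes are assumptions, not this repo's API.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class VectorQuantizer(nn.Module):
        def __init__(self, num_embeddings: int = 512, embedding_dim: int = 128, beta: float = 0.25) -> None:
            super().__init__()
            self.beta = beta
            self.codebook = nn.Embedding(num_embeddings, embedding_dim)
            self.codebook.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)

        def forward(self, z: torch.Tensor):
            # z: (B, D, H, W) encoder output with D == embedding_dim.
            b, d, h, w = z.shape
            z_flat = z.permute(0, 2, 3, 1).reshape(-1, d)
            # Nearest codebook entry for each spatial position.
            indices = torch.cdist(z_flat, self.codebook.weight).argmin(dim=1)
            z_q = self.codebook(indices).view(b, h, w, d).permute(0, 3, 1, 2)
            # Codebook loss plus the commitment loss weighted by beta (0.25 above).
            loss = F.mse_loss(z_q, z.detach()) + self.beta * F.mse_loss(z, z_q.detach())
            # Straight-through estimator: gradients skip the discrete argmin.
            z_q = z + (z_q - z).detach()
            return z_q, loss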
@@ -33,8 +33,8 @@ model:
interval: &interval step
max_lr: 1.0e-3
three_phase: true
- epochs: 512
- steps_per_epoch: 317 # num_samples / batch_size
+ epochs: 64
+ steps_per_epoch: 633 # num_samples / batch_size
criterion:
type: MSELoss
args:
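epochs drops from 512 to 64 while steps_per_epoch roughly doubles (317 -> 633), tracking the batch-size halving (64 -> 32) in the data hunk below. The keys max_lr, three_phase, epochs, and steps_per_epoch read like the keyword arguments of torch.optim.lr_scheduler.OneCycleLR; a sketch under that assumption (the Linear model and Adam optimizer are stand-ins):

    import torch

    model = torch.nn.Linear(8, 8)  # stand-in for the real network
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1.0e-3,
        three_phase=True,
        epochs=64,
        steps_per_epoch=633,  # num_samples / batch_size, with batch_size now 32
    )
    # interval: step means Lightning steps the scheduler every batch,
    # so the one-cycle schedule spans 64 * 633 = 40,512 optimizer steps.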
@@ -46,7 +46,7 @@ data:
desc: Configuration of the training/test data.
type: IAMExtendedParagraphs
args:
- batch_size: 64
+ batch_size: 32
num_workers: 12
train_fraction: 0.8
augment: true
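Halving batch_size to 32 is what drives the steps_per_epoch change above. For context, type/args pairs like this are commonly resolved by importing the named class and splatting the args; a hypothetical sketch (the module path and helper name are assumptions; the real wiring lives in run_experiment.py):

    import importlib

    from omegaconf import DictConfig  # run_experiment.py takes DictConfig configs

    def load_datamodule(config: DictConfig):
        """Hypothetical resolver for the data section of the config."""
        module = importlib.import_module("text_recognizer.data")  # assumed module path
        cls = getattr(module, config.data.type)  # e.g. IAMExtendedParagraphs
        return cls(**config.data.args)  # batch_size=32, num_workers=12, ...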
@@ -57,33 +57,33 @@ callbacks:
monitor: val_loss
mode: min
save_last: true
- # - type: StochasticWeightAveraging
- # args:
- # swa_epoch_start: 0.8
- # swa_lrs: 0.05
- # annealing_epochs: 10
- # annealing_strategy: cos
- # device: null
+ - type: StochasticWeightAveraging
+ args:
+ swa_epoch_start: 0.8
+ swa_lrs: 0.05
+ annealing_epochs: 10
+ annealing_strategy: cos
+ device: null
- type: LearningRateMonitor
args:
logging_interval: *interval
- - type: EarlyStopping
- args:
- monitor: val_loss
- mode: min
- patience: 10
+ # - type: EarlyStopping
+ # args:
+ # monitor: val_loss
+ # mode: min
+ # patience: 10
trainer:
desc: Configuration of the PyTorch Lightning Trainer.
args:
- stochastic_weight_avg: false # true
+ stochastic_weight_avg: true
auto_scale_batch_size: binsearch
gradient_clip_val: 0
fast_dev_run: false
gpus: 1
precision: 16
- max_epochs: 512
+ max_epochs: 64
terminate_on_nan: true
- weights_summary: full
+ weights_summary: top
load_checkpoint: null
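This hunk swaps the active regularizer: the StochasticWeightAveraging callback is un-commented (with the trainer flag stochastic_weight_avg flipped to true), while EarlyStopping is commented out. The argument names match pytorch_lightning.callbacks.StochasticWeightAveraging; a sketch of the equivalent direct construction under that assumption:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    swa = StochasticWeightAveraging(
        swa_epoch_start=0.8,   # begin averaging at 80% of training (~epoch 51 of 64)
        swa_lrs=0.05,          # SWA learning rate reached after annealing
        annealing_epochs=10,
        annealing_strategy="cos",
        device=None,           # None: keep the averaged weights on the model's device
    )
    trainer = Trainer(callbacks=[swa], gpus=1, precision=16, max_epochs=64)

Note that Trainer(stochastic_weight_avg=True) injects this same callback with default arguments when none is configured, so enabling both the flag and the explicit entry is likely redundant rather than harmful.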
diff --git a/training/run_experiment.py b/training/run_experiment.py
index e1aae4e..bdefbf0 100644
--- a/training/run_experiment.py
+++ b/training/run_experiment.py
@@ -22,7 +22,7 @@ def _create_experiment_dir(config: DictConfig) -> Path:
"""Creates log directory for experiment."""
log_dir = (
LOGS_DIRNAME
- / f"{config.model.type}_{config.network.type}"
+ / f"{config.model.type}_{config.network.type}".lower()
/ datetime.now().strftime("%m%d_%H%M%S")
)
log_dir.mkdir(parents=True, exist_ok=True)
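The run_experiment.py change lowercases only the model/network directory component; .lower() binds to the whole f-string result before the path join. A minimal illustration with made-up type names:

    from datetime import datetime
    from pathlib import Path

    LOGS_DIRNAME = Path("training/logs")  # assumed value of the real constant
    model_type, network_type = "LitVQVAE", "VQVAE"  # illustrative type names

    log_dir = (
        LOGS_DIRNAME
        / f"{model_type}_{network_type}".lower()  # -> "litvqvae_vqvae"
        / datetime.now().strftime("%m%d_%H%M%S")  # e.g. "0424_230920"
    )
    print(log_dir)  # training/logs/litvqvae_vqvae/0424_230920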