Diffstat (limited to 'training')
-rw-r--r--  training/conf/callbacks/checkpoint.yaml                | 12
-rw-r--r--  training/conf/callbacks/early_stopping.yaml            |  4
-rw-r--r--  training/conf/callbacks/wandb.yaml                     |  6
-rw-r--r--  training/conf/config.yaml                              |  5
-rw-r--r--  training/conf/experiment/vqvae_experiment.yaml         | 13
-rw-r--r--  training/conf/logger/wandb.yaml                        |  8
-rw-r--r--  training/conf/model/lit_transformer.yaml               |  8
-rw-r--r--  training/conf/model/mapping/word_piece.yaml (renamed from training/conf/mapping/word_piece.yaml) | 0
-rw-r--r--  training/conf/network/conv_transformer.yaml            |  8
-rw-r--r--  training/conf/network/decoder/transformer_decoder.yaml |  7
-rw-r--r--  training/conf/trainer/default.yaml                     | 27
11 files changed, 57 insertions(+), 41 deletions(-)
diff --git a/training/conf/callbacks/checkpoint.yaml b/training/conf/callbacks/checkpoint.yaml
index 9216715..db34cb1 100644
--- a/training/conf/callbacks/checkpoint.yaml
+++ b/training/conf/callbacks/checkpoint.yaml
@@ -1,9 +1,9 @@
model_checkpoint:
_target_: pytorch_lightning.callbacks.ModelCheckpoint
- monitor: "val/loss" # name of the logged metric which determines when model is improving
+ monitor: val/loss # name of the logged metric which determines when model is improving
save_top_k: 1 # save k best models (determined by above metric)
- save_last: True # additionally always save the model from the last epoch
- mode: "min" # can be "max" or "min"
- verbose: False
- dirpath: "checkpoints/"
- filename: "{epoch:02d}"
+ save_last: true # additionally always save the model from the last epoch
+ mode: min # can be "max" or "min"
+ verbose: false
+ dirpath: checkpoints/
+ filename: "{epoch:02d}"
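
This hunk drops the quotes from plain strings and booleans, which is safe, but `filename` has to stay quoted: a bare {epoch:02d} is YAML flow-mapping syntax, so the callback would receive a mapping instead of the format string ModelCheckpoint expects. A minimal sketch of how Hydra turns this block into a live callback; the inline dict just mirrors the file above, and the `cfg` variable name is an assumption:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "model_checkpoint": {
            "_target_": "pytorch_lightning.callbacks.ModelCheckpoint",
            "monitor": "val/loss",
            "save_top_k": 1,
            "save_last": True,
            "mode": "min",
            "verbose": False,
            "dirpath": "checkpoints/",
            "filename": "{epoch:02d}",  # must stay quoted in the YAML file
        }
    })
    checkpoint = instantiate(cfg.model_checkpoint)  # -> ModelCheckpoint instance
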
diff --git a/training/conf/callbacks/early_stopping.yaml b/training/conf/callbacks/early_stopping.yaml
index 4cd5aa1..a188df3 100644
--- a/training/conf/callbacks/early_stopping.yaml
+++ b/training/conf/callbacks/early_stopping.yaml
@@ -1,6 +1,6 @@
early_stopping:
_target_: pytorch_lightning.callbacks.EarlyStopping
- monitor: "val/loss" # name of the logged metric which determines when model is improving
+ monitor: val/loss # name of the logged metric which determines when model is improving
patience: 16 # how many epochs of not improving until training stops
- mode: "min" # can be "max" or "min"
+ mode: min # can be "max" or "min"
min_delta: 0 # minimum change in the monitored metric needed to qualify as an improvement
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
index 6eedb71..0017e11 100644
--- a/training/conf/callbacks/wandb.yaml
+++ b/training/conf/callbacks/wandb.yaml
@@ -3,7 +3,7 @@ defaults:
watch_model:
_target_: callbacks.wandb_callbacks.WatchModel
- log: "all"
+ log: all
log_freq: 100
upload_code_as_artifact:
@@ -12,8 +12,8 @@ upload_code_as_artifact:
upload_ckpts_as_artifact:
_target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
- ckpt_dir: "checkpoints/"
- upload_best_only: True
+ ckpt_dir: checkpoints/
+ upload_best_only: true
log_text_predictions:
_target_: callbacks.wandb_callbacks.LogTextPredictions
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index b43e375..a8e718e 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -10,5 +10,10 @@ defaults:
- checkpoint
- learning_rate_monitor
+seed: 4711
+wandb: false
+tune: false
+train: true
+test: true
load_checkpoint: null
logging: INFO
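
The run-control flags (seed, wandb, tune, train, test) move out of the trainer config and into the top-level config, where the entrypoint can read them before anything is built. A hedged sketch of the consuming side, using the Hydra 1.1-era `@hydra.main` signature; the module name and the exact group names passed to `instantiate` (`dataset`, `model`, `trainer`) are assumptions based on the defaults list, not code from this commit:

    import hydra
    from hydra.utils import instantiate
    from omegaconf import DictConfig
    from pytorch_lightning import seed_everything

    @hydra.main(config_path="conf", config_name="config")
    def main(cfg: DictConfig) -> None:
        seed_everything(cfg.seed)               # seed: 4711
        datamodule = instantiate(cfg.dataset)
        model = instantiate(cfg.model)
        trainer = instantiate(cfg.trainer)
        if cfg.tune:                            # auto lr / batch-size search
            trainer.tune(model, datamodule=datamodule)
        if cfg.train:
            trainer.fit(model, datamodule=datamodule)
        if cfg.test:
            trainer.test(model, datamodule=datamodule)

    if __name__ == "__main__":
        main()
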
diff --git a/training/conf/experiment/vqvae_experiment.yaml b/training/conf/experiment/vqvae_experiment.yaml
new file mode 100644
index 0000000..0858c3d
--- /dev/null
+++ b/training/conf/experiment/vqvae_experiment.yaml
@@ -0,0 +1,13 @@
+defaults:
+ - override /network: vqvae
+ - override /criterion: mse
+ - override /optimizer: madgrad
+ - override /lr_scheduler: one_cycle
+ - override /model: lit_vqvae
+ - override /dataset: iam_extended_paragraphs
+ - override /trainer: default
+ - override /callbacks:
+     - wandb
+
+load_checkpoint: null
+logging: INFO
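
An experiment file like this swaps whole config groups at once, so a single override reproduces the VQ-VAE setup. A sketch using Hydra's compose API (in `hydra` proper since 1.1), equivalent to passing `+experiment=vqvae_experiment` to the entrypoint; the print line is only illustrative:

    from hydra import compose, initialize

    with initialize(config_path="conf"):
        cfg = compose(
            config_name="config",
            overrides=["+experiment=vqvae_experiment"],
        )
        print(cfg.network._target_)  # should now point at the vqvae network
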
diff --git a/training/conf/logger/wandb.yaml b/training/conf/logger/wandb.yaml
index 552cf00..37bd2fe 100644
--- a/training/conf/logger/wandb.yaml
+++ b/training/conf/logger/wandb.yaml
@@ -2,14 +2,14 @@
wandb:
_target_: pytorch_lightning.loggers.wandb.WandbLogger
- project: "text-recognizer"
+ project: text-recognizer
name: null
save_dir: "."
- offline: False # set True to store all logs only locally
+ offline: false # set True to store all logs only locally
id: null # pass correct id to resume experiment!
# entity: "" # set to name of your wandb team or just remove it
- log_model: False
+ log_model: false
prefix: ""
- job_type: "train"
+ job_type: train
group: ""
tags: []
diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 4e04b85..5341d8e 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,4 +1,10 @@
+defaults:
+ - mapping: word_piece
+
_target_: text_recognizer.models.transformer.TransformerLitModel
interval: null
monitor: val/loss
-ignore_tokens: ["<s>", "<e>", "<p>"]
+ignore_tokens: [ <s>, <e>, <p> ]
+start_token: <s>
+end_token: <e>
+pad_token: <p>
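
With the mapping nested under the model config, the lit model owns both the word-piece vocabulary and the special tokens, so it can translate `<s>`/`<e>`/`<p>` to indices once and strip them before computing text metrics. A toy, self-contained sketch of that filtering: `token_to_index` stands in for the real word-piece mapping, and the index values for `<s>` and `<e>` are made up (only `<p>` -> 1002 is implied by `pad_index` in conv_transformer.yaml below):

    import torch

    token_to_index = {"<s>": 1000, "<e>": 1001, "<p>": 1002}  # toy mapping

    def filter_special(pred: torch.Tensor) -> torch.Tensor:
        """Drop the ignore_tokens from a 1D prediction before CER/WER."""
        ignore = torch.tensor([token_to_index[t] for t in ("<s>", "<e>", "<p>")])
        return pred[~torch.isin(pred, ignore)]

    print(filter_special(torch.tensor([1000, 5, 7, 1001, 1002])))  # tensor([5, 7])
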
diff --git a/training/conf/mapping/word_piece.yaml b/training/conf/model/mapping/word_piece.yaml
index 39e2ba4..39e2ba4 100644
--- a/training/conf/mapping/word_piece.yaml
+++ b/training/conf/model/mapping/word_piece.yaml
diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index f72e030..7d57a2d 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -6,8 +6,6 @@ _target_: text_recognizer.networks.conv_transformer.ConvTransformer
input_dims: [1, 576, 640]
hidden_dim: 256
dropout_rate: 0.2
-max_output_len: 682
-num_classes: 1004
-start_token: <s>
-end_token: <e>
-pad_token: <p>
+max_output_len: 451
+num_classes: 1006
+pad_index: 1002
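
The start/end/pad token strings move up to the lit model (above), while the network keeps only the numeric `pad_index` it actually needs; `num_classes` grows from 1004 to 1006 to cover the full word-piece vocabulary. A minimal sketch of the usual role of a pad index inside a network, masking padded target positions; the mask convention (True = real token) is an assumption, not taken from this repo:

    import torch

    pad_index = 1002
    targets = torch.tensor([[3, 17, 42, 1002, 1002]])
    target_mask = targets != pad_index   # True where the token is real
    print(target_mask)                   # tensor([[ True,  True,  True, False, False]])
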
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 60c5762..3122de1 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -1,21 +1,20 @@
+defaults:
+ - rotary_emb: null
+
_target_: text_recognizer.networks.transformer.Decoder
dim: 256
depth: 2
num_heads: 8
attn_fn: text_recognizer.networks.transformer.attention.Attention
attn_kwargs:
- num_heads: 8
dim_head: 64
dropout_rate: 0.2
norm_fn: torch.nn.LayerNorm
ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
ff_kwargs:
- dim: 256
dim_out: null
expansion_factor: 4
glu: true
dropout_rate: 0.2
-rotary_emb: null
-rotary_emb_dim: null
cross_attend: true
pre_norm: true
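
This hunk deduplicates values already present at the top level (`num_heads`, `dim`) out of `attn_kwargs`/`ff_kwargs`, and moves `rotary_emb` into the defaults list so it becomes a swappable config group. A hedged sketch of the dotted-path-plus-kwargs pattern the remaining fields imply; resolving with `pydoc.locate` and the exact constructor arguments of `Attention` are assumptions, and the import only works with the repo on the path:

    from pydoc import locate

    attn_fn = "text_recognizer.networks.transformer.attention.Attention"
    attn_kwargs = {"dim_head": 64, "dropout_rate": 0.2}

    attention_cls = locate(attn_fn)              # dotted path -> class
    # num_heads now lives once at the top level and is passed explicitly
    # instead of being repeated inside attn_kwargs:
    attention = attention_cls(num_heads=8, **attn_kwargs)
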
diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index 5797741..5ed6552 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -1,16 +1,11 @@
-seed: 4711
-wandb: false
-tune: false
-train: true
-test: true
-args:
- stochastic_weight_avg: false
- auto_scale_batch_size: binsearch
- auto_lr_find: false
- gradient_clip_val: 0
- fast_dev_run: false
- gpus: 1
- precision: 16
- max_epochs: 64
- terminate_on_nan: true
- weights_summary: top
+_target_: pytorch_lightning.Trainer
+stochastic_weight_avg: false
+auto_scale_batch_size: binsearch
+auto_lr_find: false
+gradient_clip_val: 0
+fast_dev_run: false
+gpus: 1
+precision: 16
+max_epochs: 64
+terminate_on_nan: true
+weights_summary: top
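
Adding `_target_` turns the trainer block from a bag of values someone has to unpack by hand (the old `args` sub-dict) into a config Hydra can instantiate directly, which is what lets the run flags move to the top-level config above. A minimal, CPU-friendly sketch; the config itself uses `gpus: 1` and `precision: 16`, swapped here only so the snippet runs anywhere:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    trainer_cfg = OmegaConf.create({
        "_target_": "pytorch_lightning.Trainer",
        "gpus": 0,               # the config uses 1; CPU so the sketch runs anywhere
        "precision": 32,         # 16-bit needs a GPU backend
        "max_epochs": 64,
        "gradient_clip_val": 0,
        "fast_dev_run": False,
        "weights_summary": "top",
    })
    trainer = instantiate(trainer_cfg)   # -> pytorch_lightning.Trainer(...)
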