diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/conf/criterion/ctc.yaml | 2 | ||||
-rw-r--r-- | training/conf/experiment/conv_perceiver_lines.yaml | 76 | ||||
-rw-r--r-- | training/conf/experiment/vq_transformer_lines.yaml | 81 | ||||
-rw-r--r-- | training/conf/model/lit_perceiver.yaml | 5 | ||||
-rw-r--r-- | training/conf/model/lit_vq_transformer.yaml | 5 | ||||
-rw-r--r-- | training/conf/network/conv_perceiver.yaml | 37 | ||||
-rw-r--r-- | training/conf/network/vq_transformer.yaml | 65 |
7 files changed, 0 insertions, 271 deletions
diff --git a/training/conf/criterion/ctc.yaml b/training/conf/criterion/ctc.yaml deleted file mode 100644 index 0677c06..0000000 --- a/training/conf/criterion/ctc.yaml +++ /dev/null @@ -1,2 +0,0 @@ -_target_: torch.nn.CTCLoss -zero_infinity: true diff --git a/training/conf/experiment/conv_perceiver_lines.yaml b/training/conf/experiment/conv_perceiver_lines.yaml deleted file mode 100644 index 26fe232..0000000 --- a/training/conf/experiment/conv_perceiver_lines.yaml +++ /dev/null @@ -1,76 +0,0 @@ -# @package _global_ - -defaults: - - override /criterion: cross_entropy - - override /callbacks: htr - - override /datamodule: iam_lines - - override /network: conv_perceiver - - override /model: lit_perceiver - - override /lr_scheduler: null - - override /optimizer: null - -tags: [lines, perceiver] -epochs: &epochs 260 -ignore_index: &ignore_index 3 -num_classes: &num_classes 57 -max_output_len: &max_output_len 89 -summary: [[1, 1, 56, 1024]] - -logger: - wandb: - tags: ${tags} - -criterion: - ignore_index: *ignore_index - # label_smoothing: 0.1 - -callbacks: - stochastic_weight_averaging: - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging - swa_epoch_start: 0.75 - swa_lrs: 1.0e-5 - annealing_epochs: 10 - annealing_strategy: cos - device: null - -optimizer: - _target_: adan_pytorch.Adan - lr: 1.0e-4 - betas: [0.02, 0.08, 0.01] - weight_decay: 0.02 - -lr_scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.8 - patience: 10 - threshold: 1.0e-4 - threshold_mode: rel - cooldown: 0 - min_lr: 1.0e-5 - eps: 1.0e-8 - verbose: false - interval: epoch - monitor: val/cer - -datamodule: - batch_size: 8 - train_fraction: 0.95 - -network: - input_dims: [1, 1, 56, 1024] - num_classes: *num_classes - pad_index: *ignore_index - encoder: - depth: 5 - decoder: - depth: 6 - -model: - max_output_len: *max_output_len - -trainer: - gradient_clip_val: 1.0 - stochastic_weight_avg: true - max_epochs: *epochs - accumulate_grad_batches: 1 diff --git a/training/conf/experiment/vq_transformer_lines.yaml b/training/conf/experiment/vq_transformer_lines.yaml deleted file mode 100644 index dbd8a3b..0000000 --- a/training/conf/experiment/vq_transformer_lines.yaml +++ /dev/null @@ -1,81 +0,0 @@ -# @package _global_ - -defaults: - - override /criterion: cross_entropy - - override /callbacks: htr - - override /datamodule: iam_lines - - override /network: vq_transformer - - override /model: lit_vq_transformer - - override /lr_scheduler: null - - override /optimizer: null - -tags: [lines] -epochs: &epochs 200 -ignore_index: &ignore_index 3 -num_classes: &num_classes 57 -max_output_len: &max_output_len 89 -summary: [[1, 1, 56, 1024], [1, 89]] - -logger: - wandb: - tags: ${tags} - # id: 342qvr1p - -criterion: - ignore_index: *ignore_index - label_smoothing: 0.05 - -callbacks: - stochastic_weight_averaging: - _target_: pytorch_lightning.callbacks.StochasticWeightAveraging - swa_epoch_start: 0.75 - swa_lrs: 1.0e-5 - annealing_epochs: 10 - annealing_strategy: cos - device: null - -optimizer: - _target_: torch.optim.RAdam - lr: 3.0e-4 - betas: [0.9, 0.999] - weight_decay: 0 - eps: 1.0e-8 - -lr_scheduler: - _target_: torch.optim.lr_scheduler.ReduceLROnPlateau - mode: min - factor: 0.8 - patience: 10 - threshold: 1.0e-4 - threshold_mode: rel - cooldown: 0 - min_lr: 1.0e-5 - eps: 1.0e-8 - verbose: false - interval: epoch - monitor: val/cer - -datamodule: - batch_size: 8 - train_fraction: 0.95 - -network: - input_dims: [1, 1, 56, 1024] - num_classes: *num_classes - pad_index: *ignore_index - encoder: - depth: 5 - decoder: - depth: 6 - pixel_embedding: - shape: [1, 127] - -model: - max_output_len: *max_output_len - vq_loss_weight: 0.1 - -trainer: - gradient_clip_val: 1.0 - max_epochs: *epochs - accumulate_grad_batches: 1 - # resume_from_checkpoint: /home/aktersnurra/projects/text-recognizer/training/logs/runs/2022-06-27/00-37-40/checkpoints/last.ckpt diff --git a/training/conf/model/lit_perceiver.yaml b/training/conf/model/lit_perceiver.yaml deleted file mode 100644 index 6d1ec82..0000000 --- a/training/conf/model/lit_perceiver.yaml +++ /dev/null @@ -1,5 +0,0 @@ -_target_: text_recognizer.models.LitPerceiver -max_output_len: 682 -start_token: <s> -end_token: <e> -pad_token: <p> diff --git a/training/conf/model/lit_vq_transformer.yaml b/training/conf/model/lit_vq_transformer.yaml deleted file mode 100644 index 4173151..0000000 --- a/training/conf/model/lit_vq_transformer.yaml +++ /dev/null @@ -1,5 +0,0 @@ -_target_: text_recognizer.models.LitVqTransformer -max_output_len: 682 -start_token: <s> -end_token: <e> -pad_token: <p> diff --git a/training/conf/network/conv_perceiver.yaml b/training/conf/network/conv_perceiver.yaml deleted file mode 100644 index 2e12db9..0000000 --- a/training/conf/network/conv_perceiver.yaml +++ /dev/null @@ -1,37 +0,0 @@ -_target_: text_recognizer.networks.ConvPerceiver -input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 128 -num_classes: &num_classes 58 -max_length: &max_length 89 -num_queries: *max_length -queries_dim: &queries_dim 64 -pad_index: 3 -encoder: - _target_: text_recognizer.networks.EfficientNet - arch: b0 - stochastic_dropout_rate: 0.2 - bn_momentum: 0.99 - bn_eps: 1.0e-3 - depth: 5 - out_channels: *hidden_dim -decoder: - _target_: text_recognizer.networks.perceiver.PerceiverIO - dim: 192 - cross_heads: 1 - cross_head_dim: 64 - num_latents: 256 - latent_dim: 512 - latent_heads: 8 - depth: 6 - queries_dim: 128 - logits_dim: *num_classes -pixel_embedding: - _target_: text_recognizer.networks.transformer.embeddings.axial.AxialPositionalEmbeddingImage - dim: 64 - axial_shape: [3, 64] - axial_dims: [32, 32] -query_pos_emb: - _target_: text_recognizer.networks.transformer.embeddings.absolute.AbsolutePositionalEmbedding - dim: 64 - max_seq_len: *max_length - l2norm_embed: true diff --git a/training/conf/network/vq_transformer.yaml b/training/conf/network/vq_transformer.yaml deleted file mode 100644 index d62a4b7..0000000 --- a/training/conf/network/vq_transformer.yaml +++ /dev/null @@ -1,65 +0,0 @@ -_target_: text_recognizer.networks.VqTransformer -input_dims: [1, 1, 576, 640] -hidden_dim: &hidden_dim 144 -num_classes: 58 -pad_index: 3 -encoder: - _target_: text_recognizer.networks.EfficientNet - arch: b0 - stochastic_dropout_rate: 0.2 - bn_momentum: 0.99 - bn_eps: 1.0e-3 - depth: 5 - out_channels: *hidden_dim -decoder: - _target_: text_recognizer.networks.transformer.Decoder - depth: 6 - block: - _target_: text_recognizer.networks.transformer.DecoderBlock - self_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 8 - dim_head: 64 - dropout_rate: &dropout_rate 0.4 - causal: true - rotary_embedding: - _target_: text_recognizer.networks.transformer.RotaryEmbedding - dim: 64 - cross_attn: - _target_: text_recognizer.networks.transformer.Attention - dim: *hidden_dim - num_heads: 8 - dim_head: 64 - dropout_rate: *dropout_rate - causal: false - norm: - _target_: text_recognizer.networks.transformer.RMSNorm - dim: *hidden_dim - ff: - _target_: text_recognizer.networks.transformer.FeedForward - dim: *hidden_dim - dim_out: null - expansion_factor: 2 - glu: true - dropout_rate: *dropout_rate -pixel_embedding: - _target_: text_recognizer.networks.transformer.AxialPositionalEmbedding - dim: *hidden_dim - shape: [18, 79] -quantizer: - _target_: text_recognizer.networks.quantizer.VectorQuantizer - input_dim: *hidden_dim - codebook: - _target_: text_recognizer.networks.quantizer.CosineSimilarityCodebook - dim: 16 - codebook_size: 64 - kmeans_init: true - kmeans_iters: 10 - decay: 0.8 - eps: 1.0e-5 - threshold_dead: 2 - temperature: 0.0 - commitment: 0.25 - ort_reg_weight: 10 - ort_reg_max_codes: 64 |