Diffstat (limited to 'training')
-rw-r--r-- | training/conf/callbacks/wandb.yaml | 20
-rw-r--r-- | training/conf/callbacks/wandb/checkpoints.yaml | 4
-rw-r--r-- | training/conf/callbacks/wandb/code.yaml | 3
-rw-r--r-- | training/conf/callbacks/wandb/image_reconstructions.yaml | 0
-rw-r--r-- | training/conf/callbacks/wandb/ocr_predictions.yaml | 3
-rw-r--r-- | training/conf/callbacks/wandb/watch.yaml | 4
-rw-r--r-- | training/conf/callbacks/wandb_ocr.yaml | 6
-rw-r--r-- | training/conf/config.yaml | 18
-rw-r--r-- | training/conf/criterion/label_smoothing.yaml | 2
-rw-r--r-- | training/conf/hydra/default.yaml | 6
-rw-r--r-- | training/conf/mapping/word_piece.yaml (renamed from training/conf/model/mapping/word_piece.yaml) | 4
-rw-r--r-- | training/conf/model/lit_transformer.yaml | 5
-rw-r--r-- | training/conf/network/conv_transformer.yaml | 2
-rw-r--r-- | training/conf/network/decoder/transformer_decoder.yaml | 4
-rw-r--r-- | training/conf/trainer/default.yaml | 6
-rw-r--r-- | training/run.py | 11
-rw-r--r-- | training/utils.py | 2
17 files changed, 55 insertions, 45 deletions
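The restructuring below breaks the monolithic wandb callback config into a `wandb/` config group and rewrites `config.yaml` so that every component (callbacks, criterion, mapping, model, network, trainer) is pulled in through Hydra's defaults list. A quick way to inspect the composed result outside the `@hydra.main` entry point — a minimal sketch, assuming Hydra >= 1.2 and a working directory from which the relative `config_path` resolves:

```python
from hydra import compose, initialize

# Compose training/conf/config.yaml with all defaults-list groups merged.
with initialize(config_path="training/conf", version_base=None):
    config = compose(config_name="config")
    # The callbacks node now holds one entry per file pulled in by
    # wandb_ocr.yaml: watch_model, upload_code_as_artifact, ...
    print(list(config.callbacks.keys()))
```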
diff --git a/training/conf/callbacks/wandb.yaml b/training/conf/callbacks/wandb.yaml
deleted file mode 100644
index 0017e11..0000000
--- a/training/conf/callbacks/wandb.yaml
+++ /dev/null
@@ -1,20 +0,0 @@
-defaults:
-  - default.yaml
-
-watch_model:
-  _target_: callbacks.wandb_callbacks.WatchModel
-  log: all
-  log_freq: 100
-
-upload_code_as_artifact:
-  _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
-  project_dir: ${work_dir}/text_recognizer
-
-upload_ckpts_as_artifact:
-  _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
-  ckpt_dir: checkpoints/
-  upload_best_only: true
-
-log_text_predictions:
-  _target_: callbacks.wandb_callbacks.LogTextPredictions
-  num_samples: 8

diff --git a/training/conf/callbacks/wandb/checkpoints.yaml b/training/conf/callbacks/wandb/checkpoints.yaml
new file mode 100644
index 0000000..a4a16ff
--- /dev/null
+++ b/training/conf/callbacks/wandb/checkpoints.yaml
@@ -0,0 +1,4 @@
+upload_ckpts_as_artifact:
+  _target_: callbacks.wandb_callbacks.UploadCheckpointsAsArtifact
+  ckpt_dir: checkpoints/
+  upload_best_only: true

diff --git a/training/conf/callbacks/wandb/code.yaml b/training/conf/callbacks/wandb/code.yaml
new file mode 100644
index 0000000..35f6ea3
--- /dev/null
+++ b/training/conf/callbacks/wandb/code.yaml
@@ -0,0 +1,3 @@
+upload_code_as_artifact:
+  _target_: callbacks.wandb_callbacks.UploadCodeAsArtifact
+  project_dir: ${work_dir}/text_recognizer

diff --git a/training/conf/callbacks/wandb/image_reconstructions.yaml b/training/conf/callbacks/wandb/image_reconstructions.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/conf/callbacks/wandb/image_reconstructions.yaml

diff --git a/training/conf/callbacks/wandb/ocr_predictions.yaml b/training/conf/callbacks/wandb/ocr_predictions.yaml
new file mode 100644
index 0000000..573fa96
--- /dev/null
+++ b/training/conf/callbacks/wandb/ocr_predictions.yaml
@@ -0,0 +1,3 @@
+log_text_predictions:
+  _target_: callbacks.wandb_callbacks.LogTextPredictions
+  num_samples: 8

diff --git a/training/conf/callbacks/wandb/watch.yaml b/training/conf/callbacks/wandb/watch.yaml
new file mode 100644
index 0000000..511608c
--- /dev/null
+++ b/training/conf/callbacks/wandb/watch.yaml
@@ -0,0 +1,4 @@
+watch_model:
+  _target_: callbacks.wandb_callbacks.WatchModel
+  log: all
+  log_freq: 100

diff --git a/training/conf/callbacks/wandb_ocr.yaml b/training/conf/callbacks/wandb_ocr.yaml
new file mode 100644
index 0000000..efa3dda
--- /dev/null
+++ b/training/conf/callbacks/wandb_ocr.yaml
@@ -0,0 +1,6 @@
+defaults:
+  - default
+  - wandb/watch
+  - wandb/code
+  - wandb/checkpoints
+  - wandb/ocr_predictions

diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index a8e718e..93215ed 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -1,19 +1,17 @@
 defaults:
-  - network: vqvae
-  - criterion: mse
-  - optimizer: madgrad
-  - lr_scheduler: one_cycle
-  - model: lit_vqvae
+  - callbacks: wandb_ocr
+  - criterion: label_smoothing
   - dataset: iam_extended_paragraphs
+  - hydra: default
+  - lr_scheduler: one_cycle
+  - mapping: word_piece
+  - model: lit_transformer
+  - network: conv_transformer
+  - optimizer: madgrad
   - trainer: default
-  - callbacks:
-      - checkpoint
-      - learning_rate_monitor

 seed: 4711
-wandb: false
 tune: false
 train: true
 test: true
-load_checkpoint: null
 logging: INFO
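With `wandb_ocr.yaml` composing the group files, the merged `callbacks` node is a flat mapping whose entries each carry a `_target_`, so the training script can instantiate them generically. A minimal sketch of that pattern (the helper name is illustrative, not taken from this repo):

```python
from typing import List

import hydra
from omegaconf import DictConfig
from pytorch_lightning import Callback


def instantiate_callbacks(config: DictConfig) -> List[Callback]:
    """Build every callback declared under config.callbacks."""
    callbacks: List[Callback] = []
    for cb_conf in config.get("callbacks", {}).values():
        # Only dict-like nodes with a _target_ are instantiable.
        if isinstance(cb_conf, DictConfig) and "_target_" in cb_conf:
            callbacks.append(hydra.utils.instantiate(cb_conf))
    return callbacks
```

The criterion fix just below follows the same rule: `hydra.utils.instantiate` imports the dotted `_target_` path and calls it, so the path must name a callable — the `LabelSmoothingLoss` class — rather than the `label_smoothing` module that contains it.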
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
index ee47c59..13daba8 100644
--- a/training/conf/criterion/label_smoothing.yaml
+++ b/training/conf/criterion/label_smoothing.yaml
@@ -1,4 +1,4 @@
-_target_: text_recognizer.criterion.label_smoothing
+_target_: text_recognizer.criterion.label_smoothing.LabelSmoothingLoss
 label_smoothing: 0.1
 vocab_size: 1006
 ignore_index: 1002

diff --git a/training/conf/hydra/default.yaml b/training/conf/hydra/default.yaml
new file mode 100644
index 0000000..dfd9721
--- /dev/null
+++ b/training/conf/hydra/default.yaml
@@ -0,0 +1,6 @@
+# output paths for hydra logs
+run:
+  dir: logs/runs/${now:%Y-%m-%d}/${now:%H-%M-%S}
+sweep:
+  dir: logs/multiruns/${now:%Y-%m-%d_%H-%M-%S}
+  subdir: ${hydra.job.num}

diff --git a/training/conf/model/mapping/word_piece.yaml b/training/conf/mapping/word_piece.yaml
index 39e2ba4..3792523 100644
--- a/training/conf/model/mapping/word_piece.yaml
+++ b/training/conf/mapping/word_piece.yaml
@@ -5,5 +5,5 @@ lexicon: iamdb_1kwp_lex_1000.txt
 data_dir: null
 use_words: false
 prepend_wordsep: false
-special_tokens: ["<s>", "<e>", "<p>"]
-extra_symbols: ["\n"]
+special_tokens: [ <s>, <e>, <p> ]
+extra_symbols: [ \n ]

diff --git a/training/conf/model/lit_transformer.yaml b/training/conf/model/lit_transformer.yaml
index 5341d8e..6ffde4e 100644
--- a/training/conf/model/lit_transformer.yaml
+++ b/training/conf/model/lit_transformer.yaml
@@ -1,8 +1,5 @@
-defaults:
-  - mapping: word_piece
-
 _target_: text_recognizer.models.transformer.TransformerLitModel
-interval: null
+interval: step
 monitor: val/loss
 ignore_tokens: [ <s>, <e>, <p> ]
 start_token: <s>

diff --git a/training/conf/network/conv_transformer.yaml b/training/conf/network/conv_transformer.yaml
index 7d57a2d..a97157d 100644
--- a/training/conf/network/conv_transformer.yaml
+++ b/training/conf/network/conv_transformer.yaml
@@ -4,7 +4,7 @@ defaults:

 _target_: text_recognizer.networks.conv_transformer.ConvTransformer
 input_dims: [1, 576, 640]
-hidden_dim: 256
+hidden_dim: 96
 dropout_rate: 0.2
 max_output_len: 451
 num_classes: 1006

diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
index 3122de1..90b9d8a 100644
--- a/training/conf/network/decoder/transformer_decoder.yaml
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -2,12 +2,12 @@ defaults:
   - rotary_emb: null

 _target_: text_recognizer.networks.transformer.Decoder
-dim: 256
+dim: 96
 depth: 2
 num_heads: 8
 attn_fn: text_recognizer.networks.transformer.attention.Attention
 attn_kwargs:
-  dim_head: 64
+  dim_head: 16
   dropout_rate: 0.2
 norm_fn: torch.nn.LayerNorm
 ff_fn: text_recognizer.networks.transformer.mlp.FeedForward

diff --git a/training/conf/trainer/default.yaml b/training/conf/trainer/default.yaml
index 5ed6552..c665adc 100644
--- a/training/conf/trainer/default.yaml
+++ b/training/conf/trainer/default.yaml
@@ -6,6 +6,10 @@ gradient_clip_val: 0
 fast_dev_run: false
 gpus: 1
 precision: 16
-max_epochs: 64
+max_epochs: 512
 terminate_on_nan: true
 weights_summary: top
+limit_train_batches: 1.0
+limit_val_batches: 1.0
+limit_test_batches: 1.0
+resume_from_checkpoint: null
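The trainer config is consumed by `hydra.utils.instantiate`, so its keys map one-to-one onto `pytorch_lightning.Trainer` arguments. For reference, the equivalent direct construction under the Lightning 1.x API these keys assume (`terminate_on_nan` and `weights_summary` were removed in later releases):

```python
from pytorch_lightning import Trainer

trainer = Trainer(
    gradient_clip_val=0,       # 0 disables gradient clipping
    fast_dev_run=False,
    gpus=1,
    precision=16,              # mixed-precision training
    max_epochs=512,
    terminate_on_nan=True,
    weights_summary="top",
    limit_train_batches=1.0,   # 1.0 = use every batch each epoch
    limit_val_batches=1.0,
    limit_test_batches=1.0,
    resume_from_checkpoint=None,
)
```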
diff --git a/training/run.py b/training/run.py
index d88a8f6..30479c6 100644
--- a/training/run.py
+++ b/training/run.py
@@ -2,7 +2,7 @@ from typing import List, Optional, Type

 import hydra
-import loguru.logger as log
+from loguru import logger as log
 from omegaconf import DictConfig
 from pytorch_lightning import (
     Callback,
@@ -12,6 +12,7 @@ from pytorch_lightning import (
     Trainer,
 )
 from pytorch_lightning.loggers import LightningLoggerBase
+from text_recognizer.data.mappings import AbstractMapping
 from torch import nn

 import utils
@@ -25,15 +26,19 @@ def run(config: DictConfig) -> Optional[float]:
     if config.get("seed"):
         seed_everything(config.seed)

+    log.info(f"Instantiating mapping <{config.mapping._target_}>")
+    mapping: AbstractMapping = hydra.utils.instantiate(config.mapping)
+
     log.info(f"Instantiating datamodule <{config.datamodule._target_}>")
-    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule)
+    datamodule: LightningDataModule = hydra.utils.instantiate(config.datamodule, mapping=mapping)

     log.info(f"Instantiating network <{config.network._target_}>")
-    network: nn.Module = hydra.utils.instantiate(config.network, **datamodule.config())
+    network: nn.Module = hydra.utils.instantiate(config.network)

     log.info(f"Instantiating model <{config.model._target_}>")
     model: LightningModule = hydra.utils.instantiate(
         **config.model,
+        mapping=mapping,
         network=network,
         criterion_config=config.criterion,
         optimizer_config=config.optimizer,

diff --git a/training/utils.py b/training/utils.py
index 564b9bb..ef74f61 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -3,7 +3,7 @@ from typing import Any, List, Type
 import warnings

 import hydra
-import loguru.logger as log
+from loguru import logger as log
 from omegaconf import DictConfig, OmegaConf
 from pytorch_lightning import (
     Callback,
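Both `run.py` and `utils.py` fix the same broken import: `logger` is an object created in `loguru/__init__.py`, not a submodule, so `import loguru.logger as log` raises `ModuleNotFoundError` at startup. Only the attribute form works:

```python
# import loguru.logger as log   # ModuleNotFoundError: No module named 'loguru.logger'
from loguru import logger as log  # correct: imports the package attribute

log.info("loguru logger is ready")
```

The larger change in `run()` is the wiring: the word-piece mapping is instantiated once up front and passed explicitly to both the datamodule and the model, and the network is built from its own config instead of deriving arguments from `datamodule.config()` — one shared mapping object instead of each component reconstructing it.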