Diffstat (limited to 'training')
-rw-r--r--   training/callbacks/wandb_callbacks.py                    |  8
-rw-r--r--   training/conf/criterion/label_smoothing.yaml             |  0
-rw-r--r--   training/conf/criterion/mse.yaml                         |  5
-rw-r--r--   training/conf/lr_scheduler/one_cycle.yaml                | 10
-rw-r--r--   training/conf/model/lit_vqvae.yaml                       |  3
-rw-r--r--   training/conf/network/decoder/transformer_decoder.yaml   | 21
-rw-r--r--   training/conf/network/encoder/efficientnet.yaml          |  6
-rw-r--r--   training/conf/optimizer/madgrad.yaml                     | 11
-rw-r--r--   training/run.py                                          |  2
-rw-r--r--   training/utils.py                                        | 12
10 files changed, 46 insertions, 32 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index d9d81f6..451b0d5 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -178,13 +178,9 @@ class LogReconstuctedImages(Callback):
         experiment.log(
             {
                 f"Reconstructions/{experiment.name}/{stage}": [
-                    [
-                        wandb.Image(img),
-                        wandb.Image(rec),
-                    ]
+                    [wandb.Image(img), wandb.Image(rec),]
                     for img, rec in zip(
-                        imgs[: self.num_samples],
-                        reconstructions[: self.num_samples],
+                        imgs[: self.num_samples], reconstructions[: self.num_samples],
                     )
                 ]
             }
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/conf/criterion/label_smoothing.yaml
diff --git a/training/conf/criterion/mse.yaml b/training/conf/criterion/mse.yaml
index 4d89cbc..ffd1403 100644
--- a/training/conf/criterion/mse.yaml
+++ b/training/conf/criterion/mse.yaml
@@ -1,3 +1,2 @@
-type: MSELoss
-args:
-  reduction: mean
+_target_: torch.nn.MSELoss
+reduction: mean
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index e8cb5c4..5afdf81 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,11 +1,11 @@
 _target_: torch.optim.lr_scheduler.OneCycleLR
 max_lr: 1.0e-3
-total_steps: None
-epochs: None
-steps_per_epoch: None
+total_steps: null
+epochs: null
+steps_per_epoch: null
 pct_start: 0.3
-anneal_strategy: 'cos'
-cycle_momentum: True
+anneal_strategy: cos
+cycle_momentum: true
 base_momentum: 0.85
 max_momentum: 0.95
 div_factor: 25.0
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 6be37e5..b337fe6 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,3 +1,2 @@
 _target_: text_recognizer.models.vqvae.VQVAELitModel
-args:
-  mapping: sentence_piece
+mapping: sentence_piece
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
new file mode 100644
index 0000000..60c5762
--- /dev/null
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -0,0 +1,21 @@
+_target_: text_recognizer.networks.transformer.Decoder
+dim: 256
+depth: 2
+num_heads: 8
+attn_fn: text_recognizer.networks.transformer.attention.Attention
+attn_kwargs:
+  num_heads: 8
+  dim_head: 64
+  dropout_rate: 0.2
+norm_fn: torch.nn.LayerNorm
+ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+ff_kwargs:
+  dim: 256
+  dim_out: null
+  expansion_factor: 4
+  glu: true
+  dropout_rate: 0.2
+rotary_emb: null
+rotary_emb_dim: null
+cross_attend: true
+pre_norm: true
diff --git a/training/conf/network/encoder/efficientnet.yaml b/training/conf/network/encoder/efficientnet.yaml
new file mode 100644
index 0000000..1b9c6da
--- /dev/null
+++ b/training/conf/network/encoder/efficientnet.yaml
@@ -0,0 +1,6 @@
+_target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+arch: b0
+out_channels: 1280
+stochastic_dropout_rate: 0.2
+bn_momentum: 0.99
+bn_eps: 1.0e-3
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
index 2f2cff9..84626d3 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizer/madgrad.yaml
@@ -1,6 +1,5 @@
-type: MADGRAD
-args:
-  lr: 1.0e-3
-  momentum: 0.9
-  weight_decay: 0
-  eps: 1.0e-6
+_target_: madgrad.MADGRAD
+lr: 1.0e-3
+momentum: 0.9
+weight_decay: 0
+eps: 1.0e-6
diff --git a/training/run.py b/training/run.py
index 695a298..f745d61 100644
--- a/training/run.py
+++ b/training/run.py
@@ -67,7 +67,7 @@ def run(config: DictConfig) -> Optional[float]:
         log.info("Training network...")
         trainer.fit(model, datamodule=datamodule)

-    if config.test:lua/cfg/themes/dark.lua
+    if config.test:
         log.info("Testing network...")
         trainer.test(model, datamodule=datamodule)

diff --git a/training/utils.py b/training/utils.py
index 88b72b7..ef74f61 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -25,9 +25,7 @@ def configure_logging(config: DictConfig) -> None:
     log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=config.logging)


-def configure_callbacks(
-    config: DictConfig,
-) -> List[Type[Callback]]:
+def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
     """Configures Lightning callbacks."""
     callbacks = []
     if config.get("callbacks"):
@@ -95,9 +93,7 @@ def empty(*args: Any, **kwargs: Any) -> None:

 @rank_zero_only
 def log_hyperparameters(
-    config: DictConfig,
-    model: LightningModule,
-    trainer: Trainer,
+    config: DictConfig, model: LightningModule, trainer: Trainer,
 ) -> None:
     """This method saves hyperparameters with the logger."""
     hparams = {}
@@ -127,9 +123,7 @@ def log_hyperparameters(
     trainer.logger.log_hyperparams = empty


-def finish(
-    logger: List[Type[LightningLoggerBase]],
-) -> None:
+def finish(logger: List[Type[LightningLoggerBase]],) -> None:
     """Makes sure everything closed properly."""
     for lg in logger:
         if isinstance(lg, WandbLogger):
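The recurring change in the config files above is the migration from a homegrown type:/args: scheme to Hydra's _target_ convention, where the fully qualified class path lives in the config itself and the remaining keys become constructor arguments. The None -> null fix in one_cycle.yaml matters for the same reason: YAML parses the bare word None as the string "None", while null is parsed as Python's None, which is what torch.optim.lr_scheduler.OneCycleLR expects for its optional arguments. Below is a minimal sketch of how such a config is consumed, assuming Hydra 1.x; it is illustrative only and not part of this commit.

    # Sketch: consuming a `_target_`-style config with hydra.utils.instantiate.
    # The inline dict mirrors training/conf/criterion/mse.yaml after this change.
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "_target_": "torch.nn.MSELoss",
            "reduction": "mean",
        }
    )

    # Hydra imports torch.nn.MSELoss and calls it with the remaining keys
    # as keyword arguments, i.e. torch.nn.MSELoss(reduction="mean").
    criterion = instantiate(cfg)

    # Arguments not known until runtime can be supplied at the call site,
    # e.g. an optimizer config needs the model's parameters when built:
    #   optimizer = instantiate(optimizer_cfg, params=model.parameters())

This is why the new YAML files for the decoder, encoder, and optimizer carry only a _target_ plus flat keyword arguments: no custom factory code is needed, since instantiate resolves the class path and forwards the rest.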