Diffstat (limited to 'training')
-rw-r--r--  training/callbacks/wandb_callbacks.py                   |  8
-rw-r--r--  training/conf/criterion/label_smoothing.yaml            |  0
-rw-r--r--  training/conf/criterion/mse.yaml                        |  5
-rw-r--r--  training/conf/lr_scheduler/one_cycle.yaml               | 10
-rw-r--r--  training/conf/model/lit_vqvae.yaml                      |  3
-rw-r--r--  training/conf/network/decoder/transformer_decoder.yaml  | 21
-rw-r--r--  training/conf/network/encoder/efficientnet.yaml         |  6
-rw-r--r--  training/conf/optimizer/madgrad.yaml                    | 11
-rw-r--r--  training/run.py                                         |  2
-rw-r--r--  training/utils.py                                       | 12
10 files changed, 46 insertions, 32 deletions
diff --git a/training/callbacks/wandb_callbacks.py b/training/callbacks/wandb_callbacks.py
index d9d81f6..451b0d5 100644
--- a/training/callbacks/wandb_callbacks.py
+++ b/training/callbacks/wandb_callbacks.py
@@ -178,13 +178,9 @@ class LogReconstuctedImages(Callback):
experiment.log(
{
f"Reconstructions/{experiment.name}/{stage}": [
- [
- wandb.Image(img),
- wandb.Image(rec),
- ]
+ [wandb.Image(img), wandb.Image(rec),]
for img, rec in zip(
- imgs[: self.num_samples],
- reconstructions[: self.num_samples],
+ imgs[: self.num_samples], reconstructions[: self.num_samples],
)
]
}
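Note: the hunk above only collapses black-style formatting. For reference, a minimal
standalone sketch of the logging pattern the callback uses; the
Reconstructions/<run>/<stage> key and the image pairing come from the hunk, while the
run setup and tensors here are hypothetical.

    import torch
    import wandb

    num_samples = 8
    imgs = torch.rand(16, 1, 28, 28)             # hypothetical input batch
    reconstructions = torch.rand(16, 1, 28, 28)  # hypothetical reconstructions

    run = wandb.init(project="text-recognizer", mode="offline")
    run.log(
        {
            f"Reconstructions/{run.name}/val": [
                [wandb.Image(img), wandb.Image(rec)]
                for img, rec in zip(imgs[:num_samples], reconstructions[:num_samples])
            ]
        }
    )
    run.finish()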
diff --git a/training/conf/criterion/label_smoothing.yaml b/training/conf/criterion/label_smoothing.yaml
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/conf/criterion/label_smoothing.yaml
diff --git a/training/conf/criterion/mse.yaml b/training/conf/criterion/mse.yaml
index 4d89cbc..ffd1403 100644
--- a/training/conf/criterion/mse.yaml
+++ b/training/conf/criterion/mse.yaml
@@ -1,3 +1,2 @@
-type: MSELoss
-args:
- reduction: mean
+_target_: torch.nn.MSELoss
+reduction: mean
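Note: replacing the custom type:/args: layout with _target_: lets Hydra build the
object directly. A minimal sketch of that mechanism, mirroring the config above:

    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    # Equivalent to loading training/conf/criterion/mse.yaml after this change.
    cfg = OmegaConf.create({"_target_": "torch.nn.MSELoss", "reduction": "mean"})
    criterion = instantiate(cfg)  # torch.nn.MSELoss(reduction="mean")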
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index e8cb5c4..5afdf81 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,11 +1,11 @@
_target_: torch.optim.lr_scheduler.OneCycleLR
max_lr: 1.0e-3
-total_steps: None
-epochs: None
-steps_per_epoch: None
+total_steps: null
+epochs: null
+steps_per_epoch: null
pct_start: 0.3
-anneal_strategy: 'cos'
-cycle_momentum: True
+anneal_strategy: cos
+cycle_momentum: true
base_momentum: 0.85
max_momentum: 0.95
div_factor: 25.0
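Note: the None -> null fix is not cosmetic. YAML has no None keyword, so the old
values parsed as the string "None", which OneCycleLR would reject; null parses as a
real null. A quick demonstration with PyYAML:

    import yaml

    yaml.safe_load("total_steps: None")  # {'total_steps': 'None'}, i.e. a string
    yaml.safe_load("total_steps: null")  # {'total_steps': None}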
diff --git a/training/conf/model/lit_vqvae.yaml b/training/conf/model/lit_vqvae.yaml
index 6be37e5..b337fe6 100644
--- a/training/conf/model/lit_vqvae.yaml
+++ b/training/conf/model/lit_vqvae.yaml
@@ -1,3 +1,2 @@
_target_: text_recognizer.models.vqvae.VQVAELitModel
-args:
- mapping: sentence_piece
+mapping: sentence_piece
diff --git a/training/conf/network/decoder/transformer_decoder.yaml b/training/conf/network/decoder/transformer_decoder.yaml
new file mode 100644
index 0000000..60c5762
--- /dev/null
+++ b/training/conf/network/decoder/transformer_decoder.yaml
@@ -0,0 +1,21 @@
+_target_: text_recognizer.networks.transformer.Decoder
+dim: 256
+depth: 2
+num_heads: 8
+attn_fn: text_recognizer.networks.transformer.attention.Attention
+attn_kwargs:
+ num_heads: 8
+ dim_head: 64
+ dropout_rate: 0.2
+norm_fn: torch.nn.LayerNorm
+ff_fn: text_recognizer.networks.transformer.mlp.FeedForward
+ff_kwargs:
+ dim: 256
+ dim_out: null
+ expansion_factor: 4
+ glu: true
+ dropout_rate: 0.2
+rotary_emb: null
+rotary_emb_dim: null
+cross_attend: true
+pre_norm: true
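Note: this diff does not show how the Decoder turns the attn_fn/norm_fn/ff_fn dotted
paths and their *_kwargs into objects. One plausible pattern, an assumption rather
than the repository's confirmed code, is importlib-based resolution:

    import importlib

    def resolve(path: str):
        """Resolve a dotted path such as 'torch.nn.LayerNorm' to the object it names."""
        module_name, _, attr = path.rpartition(".")
        return getattr(importlib.import_module(module_name), attr)

    norm_fn = resolve("torch.nn.LayerNorm")
    norm = norm_fn(256)  # dim taken from the config above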
diff --git a/training/conf/network/encoder/efficientnet.yaml b/training/conf/network/encoder/efficientnet.yaml
new file mode 100644
index 0000000..1b9c6da
--- /dev/null
+++ b/training/conf/network/encoder/efficientnet.yaml
@@ -0,0 +1,6 @@
+_target_: text_recognizer.networks.encoders.efficientnet.EfficientNet
+arch: b0
+out_channels: 1280
+stochastic_dropout_rate: 0.2
+bn_momentum: 0.99
+bn_eps: 1.0e-3
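Note: stochastic_dropout_rate presumably controls stochastic depth (randomly skipping
residual branches during training, as in the EfficientNet paper). A minimal sketch of
that technique, not the repository's actual implementation:

    import torch

    def stochastic_depth(
        x: torch.Tensor, residual: torch.Tensor, drop_rate: float, training: bool
    ) -> torch.Tensor:
        """Drop the residual branch per sample with probability `drop_rate`."""
        if not training or drop_rate == 0.0:
            return x + residual
        keep_prob = 1.0 - drop_rate
        # One Bernoulli draw per sample; scale survivors to keep the expectation fixed.
        mask = (torch.rand(x.shape[0], 1, 1, 1, device=x.device) < keep_prob).float()
        return x + residual * mask / keep_prob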
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
index 2f2cff9..84626d3 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizer/madgrad.yaml
@@ -1,6 +1,5 @@
-type: MADGRAD
-args:
- lr: 1.0e-3
- momentum: 0.9
- weight_decay: 0
- eps: 1.0e-6
+_target_: madgrad.MADGRAD
+lr: 1.0e-3
+momentum: 0.9
+weight_decay: 0
+eps: 1.0e-6
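Note: same _target_: migration as the criterion config. An optimizer also needs the
model parameters, which standard Hydra usage passes as an override at call time;
madgrad is the external MADGRAD package, and the linear model here is a placeholder:

    import torch
    from hydra.utils import instantiate
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {"_target_": "madgrad.MADGRAD", "lr": 1.0e-3, "momentum": 0.9,
         "weight_decay": 0, "eps": 1.0e-6}
    )
    model = torch.nn.Linear(4, 2)  # placeholder
    optimizer = instantiate(cfg, params=model.parameters())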
diff --git a/training/run.py b/training/run.py
index 695a298..f745d61 100644
--- a/training/run.py
+++ b/training/run.py
@@ -67,7 +67,7 @@ def run(config: DictConfig) -> Optional[float]:
log.info("Training network...")
trainer.fit(model, datamodule=datamodule)
- if config.test:lua/cfg/themes/dark.lua
+ if config.test:
log.info("Testing network...")
trainer.test(model, datamodule=datamodule)
diff --git a/training/utils.py b/training/utils.py
index 88b72b7..ef74f61 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -25,9 +25,7 @@ def configure_logging(config: DictConfig) -> None:
log.add(lambda msg: tqdm.write(msg, end=""), colorize=True, level=config.logging)
-def configure_callbacks(
- config: DictConfig,
-) -> List[Type[Callback]]:
+def configure_callbacks(config: DictConfig,) -> List[Type[Callback]]:
"""Configures Lightning callbacks."""
callbacks = []
if config.get("callbacks"):
@@ -95,9 +93,7 @@ def empty(*args: Any, **kwargs: Any) -> None:
@rank_zero_only
def log_hyperparameters(
- config: DictConfig,
- model: LightningModule,
- trainer: Trainer,
+ config: DictConfig, model: LightningModule, trainer: Trainer,
) -> None:
"""This method saves hyperparameters with the logger."""
hparams = {}
@@ -127,9 +123,7 @@ def log_hyperparameters(
trainer.logger.log_hyperparams = empty
-def finish(
- logger: List[Type[LightningLoggerBase]],
-) -> None:
+def finish(logger: List[Type[LightningLoggerBase]],) -> None:
"""Makes sure everything closed properly."""
for lg in logger:
if isinstance(lg, WandbLogger):
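Note: the hunk ends inside the WandbLogger branch. The conventional cleanup at that
point, an assumption based on the wandb API rather than anything shown in this diff,
is a single call:

    import wandb

    # Flush and close the active run so everything syncs before the process exits.
    wandb.finish()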