diff options
-rw-r--r-- | README.md | 6 | ||||
-rw-r--r-- | text_recognizer/models/vqgan.py | 12 | ||||
-rw-r--r-- | training/conf/experiment/vqgan.yaml | 17 | ||||
-rw-r--r-- | training/conf/experiment/vqvae.yaml | 8 | ||||
-rw-r--r-- | training/conf/lr_schedulers/cosine_annealing.yaml (renamed from training/conf/lr_scheduler/cosine_annealing.yaml) | 0 | ||||
-rw-r--r-- | training/conf/lr_schedulers/one_cycle.yaml (renamed from training/conf/lr_scheduler/one_cycle.yaml) | 0 | ||||
-rw-r--r-- | training/conf/optimizers/madgrad.yaml (renamed from training/conf/optimizer/madgrad.yaml) | 0 |
7 files changed, 19 insertions, 24 deletions
@@ -27,11 +27,7 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw
 (TODO: Not working atm, needed for GTN loss function)
 
 ## Todo
-- [x] Efficient-net b0 + transformer decoder
-- [x] Load everything with hydra, get it to work
-- [x] Train network
-- [ ] Weight init
-- [ ] patchgan loss
+- [ ] patchgan loss FIX THIS!! LOOK AT TAMING TRANSFORMER, MORE SPECIFICALLY SEND LAYER AND COMPUTE COEFFICIENT
 - [ ] Get VQVAE2 to work and not get loss NAN
 - [ ] Local attention for target sequence
 - [ ] Rotary embedding for target sequence
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 80653b6..7c707b1 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -39,11 +39,8 @@ class VQGANLitModel(BaseLitModel):
                 "train/loss",
                 loss,
                 prog_bar=True,
-                logger=True,
-                on_step=True,
-                on_epoch=True,
             )
-            self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log, logger=True, on_step=True, on_epoch=True)
             return loss
 
         if optimizer_idx == 1:
@@ -58,11 +55,8 @@ class VQGANLitModel(BaseLitModel):
                 "train/discriminator_loss",
                 loss,
                 prog_bar=True,
-                logger=True,
-                on_step=True,
-                on_epoch=True,
             )
-            self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            self.log_dict(log, logger=True, on_step=True, on_epoch=True)
             return loss
 
     def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -78,7 +72,7 @@ class VQGANLitModel(BaseLitModel):
             stage="val",
         )
         self.log(
-            "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+            "val/loss", loss, prog_bar=True,
         )
         self.log_dict(log)
 
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 570e7f9..554ec9e 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -8,6 +8,19 @@ defaults:
   - override /optimizers: null
   - override /lr_schedulers: null
 
+criterion:
+  _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
+  reconstruction_loss:
+    _target_: torch.nn.L1Loss
+    reduction: mean
+  discriminator:
+    _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
+    in_channels: 1
+    num_channels: 32
+    num_layers: 3
+  vq_loss_weight: 0.8
+  discriminator_weight: 0.6
+
 datamodule:
   batch_size: 8
 
@@ -33,7 +46,7 @@ lr_schedulers:
 optimizers:
   generator:
     _target_: madgrad.MADGRAD
-    lr: 2.0e-5
+    lr: 4.5e-6
     momentum: 0.5
     weight_decay: 0
     eps: 1.0e-6
@@ -42,7 +55,7 @@ optimizers:
 
   discriminator:
     _target_: madgrad.MADGRAD
-    lr: 2.0e-5
+    lr: 4.5e-6
     momentum: 0.5
     weight_decay: 0
     eps: 1.0e-6
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 397a039..8dbb257 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -10,16 +10,8 @@ defaults:
 
 trainer:
   max_epochs: 256
-  # gradient_clip_val: 0.25
 
 datamodule:
   batch_size: 8
 
-# lr_scheduler:
-  # epochs: 64
-  # steps_per_epoch: 1245
-
-# optimizer:
-  # lr: 1.0e-3
-
 summary: null
diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_schedulers/cosine_annealing.yaml
index c53ee3a..c53ee3a 100644
--- a/training/conf/lr_scheduler/cosine_annealing.yaml
+++ b/training/conf/lr_schedulers/cosine_annealing.yaml
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_schedulers/one_cycle.yaml
index c60577a..c60577a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_schedulers/one_cycle.yaml
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizers/madgrad.yaml
index a6c059d..a6c059d 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizers/madgrad.yaml