summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-10 23:25:13 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-10 23:25:13 +0200
commit441b7484348953deb7c94150675d54583ef5a81a (patch)
tree6a3e537f98a5c34fe166db7f4c60552b1212b8f5
parent82f4acabe24e5171c40afa2939a4777ba87bcc30 (diff)
Update to config and logging in VQGAN
-rw-r--r--README.md6
-rw-r--r--text_recognizer/models/vqgan.py12
-rw-r--r--training/conf/experiment/vqgan.yaml17
-rw-r--r--training/conf/experiment/vqvae.yaml8
-rw-r--r--training/conf/lr_schedulers/cosine_annealing.yaml (renamed from training/conf/lr_scheduler/cosine_annealing.yaml)0
-rw-r--r--training/conf/lr_schedulers/one_cycle.yaml (renamed from training/conf/lr_scheduler/one_cycle.yaml)0
-rw-r--r--training/conf/optimizers/madgrad.yaml (renamed from training/conf/optimizer/madgrad.yaml)0
7 files changed, 19 insertions, 24 deletions
diff --git a/README.md b/README.md
index 45314a4..ef99b4a 100644
--- a/README.md
+++ b/README.md
@@ -27,11 +27,7 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw
(TODO: Not working atm, needed for GTN loss function)
## Todo
-- [x] Efficient-net b0 + transformer decoder
-- [x] Load everything with hydra, get it to work
-- [x] Train network
-- [ ] Weight init
-- [ ] patchgan loss
+- [ ] patchgan loss FIX THIS!! LOOK AT TAMING TRANSFORMER, MORE SPECIFICALLY SEND LAYER AND COMPUTE COEFFICIENT
- [ ] Get VQVAE2 to work and not get loss NAN
- [ ] Local attention for target sequence
- [ ] Rotary embedding for target sequence
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 80653b6..7c707b1 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -39,11 +39,8 @@ class VQGANLitModel(BaseLitModel):
"train/loss",
loss,
prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
)
- self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log, logger=True, on_step=True, on_epoch=True)
return loss
if optimizer_idx == 1:
@@ -58,11 +55,8 @@ class VQGANLitModel(BaseLitModel):
"train/discriminator_loss",
loss,
prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
)
- self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ self.log_dict(log, logger=True, on_step=True, on_epoch=True)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -78,7 +72,7 @@ class VQGANLitModel(BaseLitModel):
stage="val",
)
self.log(
- "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+ "val/loss", loss, prog_bar=True,
)
self.log_dict(log)
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 570e7f9..554ec9e 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -8,6 +8,19 @@ defaults:
- override /optimizers: null
- override /lr_schedulers: null
+criterion:
+ _target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
+ reconstruction_loss:
+ _target_: torch.nn.L1Loss
+ reduction: mean
+ discriminator:
+ _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
+ in_channels: 1
+ num_channels: 32
+ num_layers: 3
+ vq_loss_weight: 0.8
+ discriminator_weight: 0.6
+
datamodule:
batch_size: 8
@@ -33,7 +46,7 @@ lr_schedulers:
optimizers:
generator:
_target_: madgrad.MADGRAD
- lr: 2.0e-5
+ lr: 4.5e-6
momentum: 0.5
weight_decay: 0
eps: 1.0e-6
@@ -42,7 +55,7 @@ optimizers:
discriminator:
_target_: madgrad.MADGRAD
- lr: 2.0e-5
+ lr: 4.5e-6
momentum: 0.5
weight_decay: 0
eps: 1.0e-6
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 397a039..8dbb257 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -10,16 +10,8 @@ defaults:
trainer:
max_epochs: 256
- # gradient_clip_val: 0.25
datamodule:
batch_size: 8
-# lr_scheduler:
- # epochs: 64
- # steps_per_epoch: 1245
-
-# optimizer:
- # lr: 1.0e-3
-
summary: null
diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_schedulers/cosine_annealing.yaml
index c53ee3a..c53ee3a 100644
--- a/training/conf/lr_scheduler/cosine_annealing.yaml
+++ b/training/conf/lr_schedulers/cosine_annealing.yaml
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_schedulers/one_cycle.yaml
index c60577a..c60577a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_schedulers/one_cycle.yaml
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizers/madgrad.yaml
index a6c059d..a6c059d 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizers/madgrad.yaml