diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-15 21:15:31 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-15 21:15:31 +0200 |
commit | da7d2171c818afefb3bad3cd66ce85fddd519c1c (patch) | |
tree | bc2cd9f2aeca62cc2793a6882ee96ab5033868e2 /text_recognizer/models | |
parent | 441b7484348953deb7c94150675d54583ef5a81a (diff) |
Updates to VQGAN loss
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/vqgan.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 7c707b1..2f67b35 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -32,13 +32,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/loss", - loss, - prog_bar=True, + "train/loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -48,13 +48,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/discriminator_loss", - loss, - prog_bar=True, + "train/discriminator_loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -68,7 +68,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="val", ) self.log( @@ -80,7 +82,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="val", ) self.log_dict(log) @@ -94,7 +98,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="test", ) self.log_dict(log) @@ -103,7 +109,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="test", ) self.log_dict(log) |