diff options
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/vqgan.py | 20 |
1 files changed, 14 insertions, 6 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 7c707b1..2f67b35 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -32,13 +32,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/loss", - loss, - prog_bar=True, + "train/loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -48,13 +48,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/discriminator_loss", - loss, - prog_bar=True, + "train/discriminator_loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -68,7 +68,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="val", ) self.log( @@ -80,7 +82,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="val", ) self.log_dict(log) @@ -94,7 +98,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="test", ) self.log_dict(log) @@ -103,7 +109,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="test", ) self.log_dict(log) |