From da7d2171c818afefb3bad3cd66ce85fddd519c1c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 15 Aug 2021 21:15:31 +0200 Subject: Updates to VQGAN loss --- text_recognizer/models/vqgan.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) (limited to 'text_recognizer/models') diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 7c707b1..2f67b35 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -32,13 +32,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/loss", - loss, - prog_bar=True, + "train/loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -48,13 +48,13 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, + global_step=self.global_step, stage="train", ) self.log( - "train/discriminator_loss", - loss, - prog_bar=True, + "train/discriminator_loss", loss, prog_bar=True, ) self.log_dict(log, logger=True, on_step=True, on_epoch=True) return loss @@ -68,7 +68,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="val", ) self.log( @@ -80,7 +82,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="val", ) self.log_dict(log) @@ -94,7 +98,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, + global_step=self.global_step, stage="test", ) self.log_dict(log) @@ -103,7 +109,9 @@ class VQGANLitModel(BaseLitModel): data=data, reconstructions=reconstructions, vq_loss=vq_loss, + decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, + global_step=self.global_step, stage="test", ) self.log_dict(log) -- cgit v1.2.3-70-g09d2