From da7d2171c818afefb3bad3cd66ce85fddd519c1c Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 15 Aug 2021 21:15:31 +0200
Subject: Updates to VQGAN loss

---
 text_recognizer/models/vqgan.py | 20 ++++++++++++++------
 1 file changed, 14 insertions(+), 6 deletions(-)

(limited to 'text_recognizer/models')

diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 7c707b1..2f67b35 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -32,13 +32,13 @@ class VQGANLitModel(BaseLitModel):
                 data=data,
                 reconstructions=reconstructions,
                 vq_loss=vq_loss,
+                decoder_last_layer=self.network.decoder.decoder[-1].weight,
                 optimizer_idx=optimizer_idx,
+                global_step=self.global_step,
                 stage="train",
             )
             self.log(
-                "train/loss",
-                loss,
-                prog_bar=True,
+                "train/loss", loss, prog_bar=True,
             )
             self.log_dict(log, logger=True, on_step=True, on_epoch=True)
             return loss
@@ -48,13 +48,13 @@ class VQGANLitModel(BaseLitModel):
                 data=data,
                 reconstructions=reconstructions,
                 vq_loss=vq_loss,
+                decoder_last_layer=self.network.decoder.decoder[-1].weight,
                 optimizer_idx=optimizer_idx,
+                global_step=self.global_step,
                 stage="train",
             )
             self.log(
-                "train/discriminator_loss",
-                loss,
-                prog_bar=True,
+                "train/discriminator_loss", loss, prog_bar=True,
             )
             self.log_dict(log, logger=True, on_step=True, on_epoch=True)
             return loss
@@ -68,7 +68,9 @@ class VQGANLitModel(BaseLitModel):
             data=data,
             reconstructions=reconstructions,
             vq_loss=vq_loss,
+            decoder_last_layer=self.network.decoder.decoder[-1].weight,
             optimizer_idx=0,
+            global_step=self.global_step,
             stage="val",
         )
         self.log(
@@ -80,7 +82,9 @@ class VQGANLitModel(BaseLitModel):
             data=data,
             reconstructions=reconstructions,
             vq_loss=vq_loss,
+            decoder_last_layer=self.network.decoder.decoder[-1].weight,
             optimizer_idx=1,
+            global_step=self.global_step,
             stage="val",
         )
         self.log_dict(log)
@@ -94,7 +98,9 @@ class VQGANLitModel(BaseLitModel):
             data=data,
             reconstructions=reconstructions,
             vq_loss=vq_loss,
+            decoder_last_layer=self.network.decoder.decoder[-1].weight,
             optimizer_idx=0,
+            global_step=self.global_step,
             stage="test",
         )
         self.log_dict(log)
@@ -103,7 +109,9 @@ class VQGANLitModel(BaseLitModel):
             data=data,
             reconstructions=reconstructions,
             vq_loss=vq_loss,
+            decoder_last_layer=self.network.decoder.decoder[-1].weight,
             optimizer_idx=1,
+            global_step=self.global_step,
             stage="test",
         )
         self.log_dict(log)
-- 
cgit v1.2.3-70-g09d2