summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqgan.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/vqgan.py')
-rw-r--r--text_recognizer/models/vqgan.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 7c707b1..2f67b35 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -32,13 +32,13 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
+ global_step=self.global_step,
stage="train",
)
self.log(
- "train/loss",
- loss,
- prog_bar=True,
+ "train/loss", loss, prog_bar=True,
)
self.log_dict(log, logger=True, on_step=True, on_epoch=True)
return loss
@@ -48,13 +48,13 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
+ global_step=self.global_step,
stage="train",
)
self.log(
- "train/discriminator_loss",
- loss,
- prog_bar=True,
+ "train/discriminator_loss", loss, prog_bar=True,
)
self.log_dict(log, logger=True, on_step=True, on_epoch=True)
return loss
@@ -68,7 +68,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
+ global_step=self.global_step,
stage="val",
)
self.log(
@@ -80,7 +82,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
+ global_step=self.global_step,
stage="val",
)
self.log_dict(log)
@@ -94,7 +98,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
+ global_step=self.global_step,
stage="test",
)
self.log_dict(log)
@@ -103,7 +109,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
+ global_step=self.global_step,
stage="test",
)
self.log_dict(log)