summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-15 21:15:31 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-08-15 21:15:31 +0200
commitda7d2171c818afefb3bad3cd66ce85fddd519c1c (patch)
treebc2cd9f2aeca62cc2793a6882ee96ab5033868e2 /text_recognizer/models
parent441b7484348953deb7c94150675d54583ef5a81a (diff)
Updates to VQGAN loss
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/vqgan.py20
1 files changed, 14 insertions, 6 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 7c707b1..2f67b35 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -32,13 +32,13 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
+ global_step=self.global_step,
stage="train",
)
self.log(
- "train/loss",
- loss,
- prog_bar=True,
+ "train/loss", loss, prog_bar=True,
)
self.log_dict(log, logger=True, on_step=True, on_epoch=True)
return loss
@@ -48,13 +48,13 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
+ global_step=self.global_step,
stage="train",
)
self.log(
- "train/discriminator_loss",
- loss,
- prog_bar=True,
+ "train/discriminator_loss", loss, prog_bar=True,
)
self.log_dict(log, logger=True, on_step=True, on_epoch=True)
return loss
@@ -68,7 +68,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
+ global_step=self.global_step,
stage="val",
)
self.log(
@@ -80,7 +82,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
+ global_step=self.global_step,
stage="val",
)
self.log_dict(log)
@@ -94,7 +98,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
+ global_step=self.global_step,
stage="test",
)
self.log_dict(log)
@@ -103,7 +109,9 @@ class VQGANLitModel(BaseLitModel):
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
+ decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
+ global_step=self.global_step,
stage="test",
)
self.log_dict(log)