summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqgan.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:05:02 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-30 23:05:02 +0200
commite115c251a2a14508cb2f234d31bc3a6eb5cc2392 (patch)
treeba49512d3ab41fd679276e8dbbf3bb181d129780 /text_recognizer/models/vqgan.py
parent64b263995159994e2cd37c1f657dfd4c98f182f7 (diff)
Rename to commitment loss
Diffstat (limited to 'text_recognizer/models/vqgan.py')
-rw-r--r--text_recognizer/models/vqgan.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 2f67b35..30b46a7 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -25,13 +25,13 @@ class VQGANLitModel(BaseLitModel):
"""Training step."""
data, _ = batch
- reconstructions, vq_loss = self(data)
+ reconstructions, commitment_loss = self(data)
if optimizer_idx == 0:
loss, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
global_step=self.global_step,
@@ -47,7 +47,7 @@ class VQGANLitModel(BaseLitModel):
loss, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
global_step=self.global_step,
@@ -62,12 +62,12 @@ class VQGANLitModel(BaseLitModel):
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Validation step."""
data, _ = batch
- reconstructions, vq_loss = self(data)
+ reconstructions, commitment_loss = self(data)
loss, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
global_step=self.global_step,
@@ -81,7 +81,7 @@ class VQGANLitModel(BaseLitModel):
_, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
global_step=self.global_step,
@@ -92,12 +92,12 @@ class VQGANLitModel(BaseLitModel):
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
data, _ = batch
- reconstructions, vq_loss = self(data)
+ reconstructions, commitment_loss = self(data)
_, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
global_step=self.global_step,
@@ -108,7 +108,7 @@ class VQGANLitModel(BaseLitModel):
_, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
global_step=self.global_step,