summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/models/vqgan.py18
1 files changed, 9 insertions, 9 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 2f67b35..30b46a7 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -25,13 +25,13 @@ class VQGANLitModel(BaseLitModel):
"""Training step."""
data, _ = batch
- reconstructions, vq_loss = self(data)
+ reconstructions, commitment_loss = self(data)
if optimizer_idx == 0:
loss, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
global_step=self.global_step,
@@ -47,7 +47,7 @@ class VQGANLitModel(BaseLitModel):
loss, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=optimizer_idx,
global_step=self.global_step,
@@ -62,12 +62,12 @@ class VQGANLitModel(BaseLitModel):
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Validation step."""
data, _ = batch
- reconstructions, vq_loss = self(data)
+ reconstructions, commitment_loss = self(data)
loss, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
global_step=self.global_step,
@@ -81,7 +81,7 @@ class VQGANLitModel(BaseLitModel):
_, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
global_step=self.global_step,
@@ -92,12 +92,12 @@ class VQGANLitModel(BaseLitModel):
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
data, _ = batch
- reconstructions, vq_loss = self(data)
+ reconstructions, commitment_loss = self(data)
_, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=0,
global_step=self.global_step,
@@ -108,7 +108,7 @@ class VQGANLitModel(BaseLitModel):
_, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
- vq_loss=vq_loss,
+ commitment_loss=commitment_loss,
decoder_last_layer=self.network.decoder.decoder[-1].weight,
optimizer_idx=1,
global_step=self.global_step,