diff options
-rw-r--r-- | text_recognizer/models/vqgan.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 2f67b35..30b46a7 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -25,13 +25,13 @@ class VQGANLitModel(BaseLitModel): """Training step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) if optimizer_idx == 0: loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, global_step=self.global_step, @@ -47,7 +47,7 @@ class VQGANLitModel(BaseLitModel): loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, global_step=self.global_step, @@ -62,12 +62,12 @@ class VQGANLitModel(BaseLitModel): def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, global_step=self.global_step, @@ -81,7 +81,7 @@ class VQGANLitModel(BaseLitModel): _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, global_step=self.global_step, @@ -92,12 +92,12 @@ class VQGANLitModel(BaseLitModel): def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, global_step=self.global_step, @@ -108,7 +108,7 @@ class VQGANLitModel(BaseLitModel): _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, global_step=self.global_step, |