From e115c251a2a14508cb2f234d31bc3a6eb5cc2392 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Thu, 30 Sep 2021 23:05:02 +0200 Subject: Rename to commitment loss --- text_recognizer/models/vqgan.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 2f67b35..30b46a7 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -25,13 +25,13 @@ class VQGANLitModel(BaseLitModel): """Training step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) if optimizer_idx == 0: loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, global_step=self.global_step, @@ -47,7 +47,7 @@ class VQGANLitModel(BaseLitModel): loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, global_step=self.global_step, @@ -62,12 +62,12 @@ class VQGANLitModel(BaseLitModel): def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, global_step=self.global_step, @@ -81,7 +81,7 @@ class VQGANLitModel(BaseLitModel): _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, global_step=self.global_step, @@ -92,12 +92,12 @@ class VQGANLitModel(BaseLitModel): def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, global_step=self.global_step, @@ -108,7 +108,7 @@ class VQGANLitModel(BaseLitModel): _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, global_step=self.global_step, -- cgit v1.2.3-70-g09d2