diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:05:02 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-30 23:05:02 +0200 |
commit | e115c251a2a14508cb2f234d31bc3a6eb5cc2392 (patch) | |
tree | ba49512d3ab41fd679276e8dbbf3bb181d129780 /text_recognizer/models | |
parent | 64b263995159994e2cd37c1f657dfd4c98f182f7 (diff) |
Rename to commitment loss
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/vqgan.py | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 2f67b35..30b46a7 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -25,13 +25,13 @@ class VQGANLitModel(BaseLitModel): """Training step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) if optimizer_idx == 0: loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, global_step=self.global_step, @@ -47,7 +47,7 @@ class VQGANLitModel(BaseLitModel): loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=optimizer_idx, global_step=self.global_step, @@ -62,12 +62,12 @@ class VQGANLitModel(BaseLitModel): def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) loss, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, global_step=self.global_step, @@ -81,7 +81,7 @@ class VQGANLitModel(BaseLitModel): _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, global_step=self.global_step, @@ -92,12 +92,12 @@ class VQGANLitModel(BaseLitModel): def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=0, global_step=self.global_step, @@ -108,7 +108,7 @@ class VQGANLitModel(BaseLitModel): _, log = self.loss_fn( data=data, reconstructions=reconstructions, - vq_loss=vq_loss, + commitment_loss=commitment_loss, decoder_last_layer=self.network.decoder.decoder[-1].weight, optimizer_idx=1, global_step=self.global_step, |