From d3afa310f77f47553586eeee58e3d3345a754e2c Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Wed, 4 Aug 2021 05:03:51 +0200 Subject: New VQVAE --- text_recognizer/models/vqvae.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) (limited to 'text_recognizer/models/vqvae.py') diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 22da018..5890fd9 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -14,31 +14,33 @@ from text_recognizer.models.base import BaseLitModel class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" + latent_loss_weight: float = attr.ib(default=0.25) + def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" - return self.network.predict(data) + return self.network(data) def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("train/loss", loss) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("val/loss", loss, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self.network(data) + reconstructions, vq_loss = self(data) loss = self.loss_fn(reconstructions, data) - loss += vq_loss + loss += self.latent_loss_weight * vq_loss self.log("test/loss", loss) -- cgit v1.2.3-70-g09d2