diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/vqvae.py | 34 | ||||
-rw-r--r-- | text_recognizer/networks/vqvae/quantizer.py | 18 |
2 files changed, 24 insertions, 28 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 56229b3..92f28ad 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -11,7 +11,7 @@ from text_recognizer.models.base import BaseLitModel class VQVAELitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" - latent_loss_weight: float = attr.ib(default=0.25) + commitment: float = attr.ib(default=0.25) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" @@ -21,37 +21,35 @@ class VQVAELitModel(BaseLitModel): """Training step.""" data, _ = batch - reconstructions, vq_loss = self(data) + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss + loss = loss + self.commitment * commitment_loss - self.log("train/vq_loss", vq_loss) + self.log("train/commitment_loss", commitment_loss) self.log("train/loss", loss) - - # self.train_acc(reconstructions, data) - # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, _ = batch - reconstructions, vq_loss = self(data) + + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss + loss = loss + self.commitment * commitment_loss - self.log("val/vq_loss", vq_loss) + self.log("val/commitment_loss", commitment_loss) self.log("val/loss", loss, prog_bar=True) - # self.val_acc(reconstructions, data) - # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, _ = batch - reconstructions, vq_loss = self(data) + + reconstructions, commitment_loss = self(data) + loss = self.loss_fn(reconstructions, data) - loss = loss + self.latent_loss_weight * vq_loss - self.log("test/vq_loss", vq_loss) + loss = loss + self.commitment * commitment_loss + + self.log("test/commitment_loss", commitment_loss) self.log("test/loss", loss) - # self.test_acc(reconstructions, data) - # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py index 6fb57e8..bba9b60 100644 --- a/text_recognizer/networks/vqvae/quantizer.py +++ b/text_recognizer/networks/vqvae/quantizer.py @@ -34,7 +34,7 @@ class VectorQuantizer(nn.Module): self.decay = decay self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim) - def discretization_bottleneck(self, latent: Tensor) -> Tensor: + def _discretization_bottleneck(self, latent: Tensor) -> Tensor: """Computes the code nearest to the latent representation. First we compute the posterior categorical distribution, and then map @@ -78,11 +78,11 @@ class VectorQuantizer(nn.Module): quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w ) if self.training: - self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent) + self._compute_ema(one_hot_encoding=one_hot_encoding, latent=latent) return quantized_latent - def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: + def _compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None: """Computes the EMA update to the codebook.""" batch_cluster_size = one_hot_encoding.sum(axis=0) batch_embedding_avg = (latent.t() @ one_hot_encoding).t() @@ -97,7 +97,7 @@ class VectorQuantizer(nn.Module): ).unsqueeze(1) self.embedding.weight.data.copy_(new_embedding) - def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: + def _commitment_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor: """Vector Quantization loss. The vector quantization algorithm allows us to create a codebook. The VQ @@ -119,10 +119,8 @@ class VectorQuantizer(nn.Module): Tensor: The combinded VQ loss. """ - commitment_loss = F.mse_loss(quantized_latent.detach(), latent) - # embedding_loss = F.mse_loss(quantized_latent, latent.detach()) - # return embedding_loss + self.beta * commitment_loss - return commitment_loss + loss = F.mse_loss(quantized_latent.detach(), latent) + return loss def forward(self, latent: Tensor) -> Tensor: """Forward pass that returns the quantized vector and the vq loss.""" @@ -130,9 +128,9 @@ class VectorQuantizer(nn.Module): latent = rearrange(latent, "b d h w -> b h w d") # Maps latent to the nearest code in the codebook. - quantized_latent = self.discretization_bottleneck(latent) + quantized_latent = self._discretization_bottleneck(latent) - loss = self.vq_loss(latent, quantized_latent) + loss = self._commitment_loss(latent, quantized_latent) # Add residue to the quantized latent. quantized_latent = latent + (quantized_latent - latent).detach() |