-rw-r--r--  text_recognizer/models/vqvae.py              | 34
-rw-r--r--  text_recognizer/networks/vqvae/quantizer.py  | 18
-rw-r--r--  training/conf/experiment/vqgan.yaml          | 60
3 files changed, 55 insertions, 57 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 56229b3..92f28ad 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -11,7 +11,7 @@ from text_recognizer.models.base import BaseLitModel
 class VQVAELitModel(BaseLitModel):
     """A PyTorch Lightning model for transformer networks."""
 
-    latent_loss_weight: float = attr.ib(default=0.25)
+    commitment: float = attr.ib(default=0.25)
 
     def forward(self, data: Tensor) -> Tensor:
         """Forward pass with the transformer network."""
@@ -21,37 +21,35 @@ class VQVAELitModel(BaseLitModel):
         """Training step."""
         data, _ = batch
-        reconstructions, vq_loss = self(data)
+        reconstructions, commitment_loss = self(data)
+
         loss = self.loss_fn(reconstructions, data)
-        loss = loss + self.latent_loss_weight * vq_loss
+        loss = loss + self.commitment * commitment_loss
 
-        self.log("train/vq_loss", vq_loss)
+        self.log("train/commitment_loss", commitment_loss)
         self.log("train/loss", loss)
-
-        # self.train_acc(reconstructions, data)
-        # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
         return loss
 
     def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Validation step."""
         data, _ = batch
-        reconstructions, vq_loss = self(data)
+
+        reconstructions, commitment_loss = self(data)
+
         loss = self.loss_fn(reconstructions, data)
-        loss = loss + self.latent_loss_weight * vq_loss
+        loss = loss + self.commitment * commitment_loss
 
-        self.log("val/vq_loss", vq_loss)
+        self.log("val/commitment_loss", commitment_loss)
         self.log("val/loss", loss, prog_bar=True)
 
-        # self.val_acc(reconstructions, data)
-        # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
-
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
         data, _ = batch
-        reconstructions, vq_loss = self(data)
+
+        reconstructions, commitment_loss = self(data)
+
         loss = self.loss_fn(reconstructions, data)
-        loss = loss + self.latent_loss_weight * vq_loss
-        self.log("test/vq_loss", vq_loss)
+        loss = loss + self.commitment * commitment_loss
+
+        self.log("test/commitment_loss", commitment_loss)
         self.log("test/loss", loss)
-        # self.test_acc(reconstructions, data)
-        # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
diff --git a/text_recognizer/networks/vqvae/quantizer.py b/text_recognizer/networks/vqvae/quantizer.py
index 6fb57e8..bba9b60 100644
--- a/text_recognizer/networks/vqvae/quantizer.py
+++ b/text_recognizer/networks/vqvae/quantizer.py
@@ -34,7 +34,7 @@ class VectorQuantizer(nn.Module):
         self.decay = decay
         self.embedding = EmbeddingEMA(self.num_embeddings, self.embedding_dim)
 
-    def discretization_bottleneck(self, latent: Tensor) -> Tensor:
+    def _discretization_bottleneck(self, latent: Tensor) -> Tensor:
         """Computes the code nearest to the latent representation.
 
         First we compute the posterior categorical distribution, and then map
@@ -78,11 +78,11 @@ class VectorQuantizer(nn.Module):
             quantized_latent, "(b h w) d -> b h w d", b=b, h=h, w=w
         )
         if self.training:
-            self.compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)
+            self._compute_ema(one_hot_encoding=one_hot_encoding, latent=latent)
 
         return quantized_latent
 
-    def compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
+    def _compute_ema(self, one_hot_encoding: Tensor, latent: Tensor) -> None:
         """Computes the EMA update to the codebook."""
         batch_cluster_size = one_hot_encoding.sum(axis=0)
         batch_embedding_avg = (latent.t() @ one_hot_encoding).t()
@@ -97,7 +97,7 @@ class VectorQuantizer(nn.Module):
         ).unsqueeze(1)
         self.embedding.weight.data.copy_(new_embedding)
 
-    def vq_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
+    def _commitment_loss(self, latent: Tensor, quantized_latent: Tensor) -> Tensor:
         """Vector Quantization loss.
 
         The vector quantization algorithm allows us to create a codebook. The VQ
@@ -119,10 +119,8 @@ class VectorQuantizer(nn.Module):
             Tensor: The combinded VQ loss.
 
         """
-        commitment_loss = F.mse_loss(quantized_latent.detach(), latent)
-        # embedding_loss = F.mse_loss(quantized_latent, latent.detach())
-        # return embedding_loss + self.beta * commitment_loss
-        return commitment_loss
+        loss = F.mse_loss(quantized_latent.detach(), latent)
+        return loss
 
     def forward(self, latent: Tensor) -> Tensor:
         """Forward pass that returns the quantized vector and the vq loss."""
@@ -130,9 +128,9 @@ class VectorQuantizer(nn.Module):
         latent = rearrange(latent, "b d h w -> b h w d")
 
         # Maps latent to the nearest code in the codebook.
-        quantized_latent = self.discretization_bottleneck(latent)
+        quantized_latent = self._discretization_bottleneck(latent)
 
-        loss = self.vq_loss(latent, quantized_latent)
+        loss = self._commitment_loss(latent, quantized_latent)
 
         # Add residue to the quantized latent.
         quantized_latent = latent + (quantized_latent - latent).detach()
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
index 485e963..9224bc7 100644
--- a/training/conf/experiment/vqgan.yaml
+++ b/training/conf/experiment/vqgan.yaml
@@ -26,36 +26,38 @@ criterion:
 datamodule:
   batch_size: 6
 
-lr_schedulers:
-  generator:
-    _target_: torch.optim.lr_scheduler.OneCycleLR
-    max_lr: 3.0e-4
-    total_steps: null
-    epochs: 100
-    steps_per_epoch: 3369
-    pct_start: 0.1
-    anneal_strategy: cos
-    cycle_momentum: true
-    base_momentum: 0.85
-    max_momentum: 0.95
-    div_factor: 1.0e3
-    final_div_factor: 1.0e4
-    three_phase: true
-    last_epoch: -1
-    verbose: false
-
-    # Non-class arguments
-    interval: step
-    monitor: val/loss
-
-  discriminator:
-    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
-    T_max: 64
-    eta_min: 0.0
-    last_epoch: -1
+lr_schedulers: null
 
-    interval: epoch
-    monitor: val/loss
+# lr_schedulers:
+#   generator:
+#     _target_: torch.optim.lr_scheduler.OneCycleLR
+#     max_lr: 3.0e-4
+#     total_steps: null
+#     epochs: 100
+#     steps_per_epoch: 3369
+#     pct_start: 0.1
+#     anneal_strategy: cos
+#     cycle_momentum: true
+#     base_momentum: 0.85
+#     max_momentum: 0.95
+#     div_factor: 1.0e3
+#     final_div_factor: 1.0e4
+#     three_phase: true
+#     last_epoch: -1
+#     verbose: false
+#
+#     # Non-class arguments
+#     interval: step
+#     monitor: val/loss
+#
+#   discriminator:
+#     _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+#     T_max: 64
+#     eta_min: 0.0
+#     last_epoch: -1
+#
+#     interval: epoch
+#     monitor: val/loss
 
 optimizers:
   generator:
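For context on what the renamed _discretization_bottleneck, _commitment_loss, and the "add residue" line in the quantizer diff compute, here is a minimal, self-contained sketch of a vector-quantization bottleneck. It is illustrative only: MiniQuantizer is a hypothetical name, the codebook is a plain nn.Embedding rather than the repo's EmbeddingEMA, and only the (b, h, w, d) shape convention and the two loss/straight-through lines are taken from the diff above.

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class MiniQuantizer(nn.Module):
    """Illustrative nearest-code quantizer (not the repo's EMA version)."""

    def __init__(self, num_embeddings: int, embedding_dim: int) -> None:
        super().__init__()
        self.codebook = nn.Embedding(num_embeddings, embedding_dim)

    def forward(self, latent: Tensor) -> Tuple[Tensor, Tensor]:
        # latent: (b, h, w, d) -- flatten the spatial positions to rows.
        b, h, w, d = latent.shape
        flat = latent.reshape(-1, d)

        # Nearest codebook entry per position (Euclidean distance).
        distances = torch.cdist(flat, self.codebook.weight)
        codes = distances.argmin(dim=-1)
        quantized = self.codebook(codes).reshape(b, h, w, d)

        # Commitment loss: pull the encoder output towards its (frozen) code,
        # matching F.mse_loss(quantized_latent.detach(), latent) in the diff.
        commitment_loss = F.mse_loss(quantized.detach(), latent)

        # Straight-through estimator: gradients skip the non-differentiable
        # argmin, exactly like the "add residue" line in the diff.
        quantized = latent + (quantized - latent).detach()
        return quantized, commitment_loss

The commitment_loss returned here is the term that VQVAELitModel now weights with commitment (default 0.25) before adding it to the reconstruction loss. A full VQ loss would normally also include the embedding term F.mse_loss(quantized, latent.detach()); the diff drops that term (the commented-out embedding_loss) because the codebook is instead updated by the EMA rule in _compute_ema.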
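The _compute_ema hunk only shows the first and last statements of that update. As a rough sketch of the standard exponential-moving-average codebook rule it appears to implement (the function name, the explicit cluster_size/embedding_avg arguments, and the eps value are assumptions here; in the repo these buffers live inside EmbeddingEMA):

import torch
from torch import Tensor


@torch.no_grad()
def ema_codebook_update(
    weight: Tensor,            # (K, D) current codebook
    cluster_size: Tensor,      # (K,)  running EMA of code usage counts
    embedding_avg: Tensor,     # (K, D) running EMA of summed latents per code
    one_hot_encoding: Tensor,  # (N, K) hard assignments for the batch
    latent: Tensor,            # (N, D) flattened encoder outputs
    decay: float = 0.99,
    eps: float = 1e-5,
) -> Tensor:
    """Update the codebook in place from one batch of assignments."""
    # Same two batch statistics as in the diff above.
    batch_cluster_size = one_hot_encoding.sum(dim=0)           # (K,)
    batch_embedding_avg = (latent.t() @ one_hot_encoding).t()  # (K, D)

    # Exponential moving averages of counts and sums.
    cluster_size.mul_(decay).add_(batch_cluster_size, alpha=1 - decay)
    embedding_avg.mul_(decay).add_(batch_embedding_avg, alpha=1 - decay)

    # Laplace smoothing keeps rarely used codes from collapsing to zero.
    n = cluster_size.sum()
    smoothed = (cluster_size + eps) / (n + weight.shape[0] * eps) * n

    # New code vectors are the (smoothed) average of their assigned latents.
    weight.copy_(embedding_avg / smoothed.unsqueeze(1))
    return weight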
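Finally, the vqgan.yaml change disables the scheduler configs by setting lr_schedulers to null and keeping the old blocks as comments. For reference, the commented-out generator block corresponds to the plain PyTorch call below (values copied from the config; the model and optimizer are placeholders, and the trailing interval/monitor keys are Lightning bookkeeping rather than OneCycleLR arguments):

from torch import nn, optim

model = nn.Linear(8, 8)  # placeholder module standing in for the generator
optimizer = optim.Adam(model.parameters(), lr=3.0e-4)

scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=3.0e-4,
    epochs=100,            # total_steps is null in the config, so the cycle
    steps_per_epoch=3369,  # length is epochs * steps_per_epoch
    pct_start=0.1,
    anneal_strategy="cos",
    cycle_momentum=True,
    base_momentum=0.85,
    max_momentum=0.95,
    div_factor=1.0e3,
    final_div_factor=1.0e4,
    three_phase=True,
    last_epoch=-1,
)

With lr_schedulers set to null, the experiment presumably runs at the constant learning rates defined under optimizers; re-enabling a schedule is a matter of uncommenting the block above.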