From 82f4acabe24e5171c40afa2939a4777ba87bcc30 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 8 Aug 2021 21:43:39 +0200 Subject: Add training of VQGAN --- text_recognizer/criterions/vqgan_loss.py | 23 ++++++++--------------- text_recognizer/models/base.py | 8 +++++--- text_recognizer/models/vqgan.py | 24 ++---------------------- 3 files changed, 15 insertions(+), 40 deletions(-) (limited to 'text_recognizer') diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index 8bb568f..87f0f1c 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -1,5 +1,5 @@ """VQGAN loss for PyTorch Lightning.""" -from typing import Dict +from typing import Dict, Optional from click.types import Tuple import torch @@ -40,9 +40,9 @@ class VQGANLoss(nn.Module): vq_loss: Tensor, optimizer_idx: int, stage: str, - ) -> Tuple[Tensor, Dict[str, Tensor]]: + ) -> Optional[Tuple]: """Calculates the VQGAN loss.""" - rec_loss = self.reconstruction_loss( + rec_loss: Tensor = self.reconstruction_loss( data.contiguous(), reconstructions.contiguous() ) @@ -51,13 +51,13 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous()) g_loss = -torch.mean(logits_fake) - loss = ( + loss: Tensor = ( rec_loss + self.discriminator_weight * g_loss + self.vq_loss_weight * vq_loss ) log = { - f"{stage}/loss": loss, + f"{stage}/total_loss": loss, f"{stage}/vq_loss": vq_loss, f"{stage}/rec_loss": rec_loss, f"{stage}/g_loss": g_loss, @@ -68,18 +68,11 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_real = self.discriminator(data.contiguous().detach()) - d_loss = self.adversarial_loss( + d_loss = self.discriminator_weight * self.adversarial_loss( logits_real=logits_real, logits_fake=logits_fake ) - loss = ( - rec_loss - + self.discriminator_weight * d_loss - + self.vq_loss_weight * vq_loss - ) + log = { - f"{stage}/loss": loss, - f"{stage}/vq_loss": vq_loss, - f"{stage}/rec_loss": rec_loss, f"{stage}/d_loss": d_loss, } - return loss, log + return d_loss, log diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 8b68ed9..94dbde5 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -49,12 +49,14 @@ class BaseLitModel(LightningModule): """Configures the optimizer.""" optimizers = [] for optimizer_config in self.optimizer_configs.values(): - network = getattr(self, optimizer_config.parameters) + module = self + for m in str(optimizer_config.parameters).split("."): + module = getattr(module, m) del optimizer_config.parameters log.info(f"Instantiating optimizer <{optimizer_config._target_}>") optimizers.append( hydra.utils.instantiate( - self.optimizer_config, params=network.parameters() + optimizer_config, params=module.parameters() ) ) return optimizers @@ -92,7 +94,7 @@ class BaseLitModel(LightningModule): ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" optimizers = self._configure_optimizer() - schedulers = self._configure_lr_scheduler(optimizers) + schedulers = self._configure_lr_schedulers(optimizers) return optimizers, schedulers def forward(self, data: Tensor) -> Tensor: diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py index 8ff65cc..80653b6 100644 --- a/text_recognizer/models/vqgan.py +++ b/text_recognizer/models/vqgan.py @@ -9,7 +9,7 @@ from text_recognizer.criterions.vqgan_loss import VQGANLoss @attr.s(auto_attribs=True, eq=False) -class VQVAELitModel(BaseLitModel): +class VQGANLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" loss_fn: VQGANLoss = attr.ib() @@ -26,7 +26,6 @@ class VQVAELitModel(BaseLitModel): data, _ = batch reconstructions, vq_loss = self(data) - loss = self.loss_fn(reconstructions, data) if optimizer_idx == 0: loss, log = self.loss_fn( @@ -81,14 +80,6 @@ class VQVAELitModel(BaseLitModel): self.log( "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True ) - self.log( - "val/rec_loss", - log["val/rec_loss"], - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) self.log_dict(log) _, log = self.loss_fn( @@ -105,24 +96,13 @@ class VQVAELitModel(BaseLitModel): data, _ = batch reconstructions, vq_loss = self(data) - loss, log = self.loss_fn( + _, log = self.loss_fn( data=data, reconstructions=reconstructions, vq_loss=vq_loss, optimizer_idx=0, stage="test", ) - self.log( - "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True - ) - self.log( - "test/rec_loss", - log["test/rec_loss"], - prog_bar=True, - logger=True, - on_step=True, - on_epoch=True, - ) self.log_dict(log) _, log = self.loss_fn( -- cgit v1.2.3-70-g09d2