From 240f5e9f20032e82515fa66ce784619527d1041e Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 8 Aug 2021 19:59:55 +0200 Subject: Add VQGAN and loss function --- .../criterions/n_layer_discriminator.py | 58 +++++++++ text_recognizer/criterions/vqgan_loss.py | 85 +++++++++++++ text_recognizer/data/emnist_mapping.py | 6 +- text_recognizer/models/base.py | 81 ++++++++----- text_recognizer/models/transformer.py | 4 + text_recognizer/models/vqgan.py | 135 +++++++++++++++++++++ text_recognizer/models/vqvae.py | 12 +- text_recognizer/networks/vqvae/decoder.py | 13 +- text_recognizer/networks/vqvae/encoder.py | 11 +- text_recognizer/networks/vqvae/norm.py | 6 +- text_recognizer/networks/vqvae/pixelcnn.py | 12 +- text_recognizer/networks/vqvae/residual.py | 8 +- text_recognizer/networks/vqvae/vqvae.py | 1 - 13 files changed, 376 insertions(+), 56 deletions(-) create mode 100644 text_recognizer/criterions/n_layer_discriminator.py create mode 100644 text_recognizer/criterions/vqgan_loss.py create mode 100644 text_recognizer/models/vqgan.py (limited to 'text_recognizer') diff --git a/text_recognizer/criterions/n_layer_discriminator.py b/text_recognizer/criterions/n_layer_discriminator.py new file mode 100644 index 0000000..e5f8449 --- /dev/null +++ b/text_recognizer/criterions/n_layer_discriminator.py @@ -0,0 +1,58 @@ +"""Pix2pix discriminator loss.""" +from torch import nn, Tensor + +from text_recognizer.networks.vqvae.norm import Normalize + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator loss in Pix2Pix.""" + + def __init__( + self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 + ) -> None: + super().__init__() + self.in_channels = in_channels + self.num_channels = num_channels + self.num_layers = num_layers + self.discriminator = self._build_discriminator() + + def _build_discriminator(self) -> nn.Sequential: + """Builds discriminator.""" + discriminator = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.num_channels, + kernel_size=4, + stride=2, + padding=1, + ), + nn.Mish(inplace=True), + ] + in_channels = self.num_channels + for n in range(1, self.num_layers): + discriminator += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels * n, + kernel_size=4, + stride=2, + padding=1, + ), + Normalize(num_channels=in_channels * n), + nn.Mish(inplace=True), + ] + in_channels *= n + + discriminator += [ + nn.Conv2d( + in_channels=self.num_channels * (self.num_layers - 1), + out_channels=1, + kernel_size=4, + padding=1, + ) + ] + return nn.Sequential(*discriminator) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through discriminator.""" + return self.discriminator(x) diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py new file mode 100644 index 0000000..8bb568f --- /dev/null +++ b/text_recognizer/criterions/vqgan_loss.py @@ -0,0 +1,85 @@ +"""VQGAN loss for PyTorch Lightning.""" +from typing import Dict +from click.types import Tuple + +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator + + +class VQGANLoss(nn.Module): + """VQGAN loss.""" + + def __init__( + self, + reconstruction_loss: nn.L1Loss, + discriminator: NLayerDiscriminator, + vq_loss_weight: float = 1.0, + discriminator_weight: float = 1.0, + ) -> None: + super().__init__() + self.reconstruction_loss = reconstruction_loss + self.discriminator = discriminator + self.vq_loss_weight = vq_loss_weight + self.discriminator_weight = discriminator_weight + + @staticmethod + def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: + """Calculates the adversarial loss.""" + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = (loss_real + loss_fake) / 2.0 + return d_loss + + def forward( + self, + data: Tensor, + reconstructions: Tensor, + vq_loss: Tensor, + optimizer_idx: int, + stage: str, + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Calculates the VQGAN loss.""" + rec_loss = self.reconstruction_loss( + data.contiguous(), reconstructions.contiguous() + ) + + # GAN part. + if optimizer_idx == 0: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + + loss = ( + rec_loss + + self.discriminator_weight * g_loss + + self.vq_loss_weight * vq_loss + ) + log = { + f"{stage}/loss": loss, + f"{stage}/vq_loss": vq_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/g_loss": g_loss, + } + return loss, log + + if optimizer_idx == 1: + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + logits_real = self.discriminator(data.contiguous().detach()) + + d_loss = self.adversarial_loss( + logits_real=logits_real, logits_fake=logits_fake + ) + loss = ( + rec_loss + + self.discriminator_weight * d_loss + + self.vq_loss_weight * vq_loss + ) + log = { + f"{stage}/loss": loss, + f"{stage}/vq_loss": vq_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/d_loss": d_loss, + } + return loss, log diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py index 3e91594..4406db7 100644 --- a/text_recognizer/data/emnist_mapping.py +++ b/text_recognizer/data/emnist_mapping.py @@ -9,7 +9,9 @@ from text_recognizer.data.emnist import emnist_mapping class EmnistMapping(AbstractMapping): - def __init__(self, extra_symbols: Optional[Set[str]] = None, lower: bool = True) -> None: + def __init__( + self, extra_symbols: Optional[Set[str]] = None, lower: bool = True + ) -> None: self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( self.extra_symbols @@ -20,10 +22,12 @@ class EmnistMapping(AbstractMapping): def _to_lower(self) -> None: """Converts mapping to lowercase letters only.""" + def _filter(x: int) -> int: if 40 <= x: return x - 26 return x + self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)} self.mapping = [c for c in self.mapping if not c.isupper()] diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index ab3fa35..8b68ed9 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -24,8 +24,8 @@ class BaseLitModel(LightningModule): network: Type[nn.Module] = attr.ib() mapping: Type[AbstractMapping] = attr.ib() loss_fn: Type[nn.Module] = attr.ib() - optimizer_config: DictConfig = attr.ib() - lr_scheduler_config: DictConfig = attr.ib() + optimizer_configs: DictConfig = attr.ib() + lr_scheduler_configs: DictConfig = attr.ib() train_acc: torchmetrics.Accuracy = attr.ib( init=False, default=torchmetrics.Accuracy() ) @@ -45,40 +45,55 @@ class BaseLitModel(LightningModule): ) -> None: optimizer.zero_grad(set_to_none=True) - def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: + def _configure_optimizer(self) -> List[Type[torch.optim.Optimizer]]: """Configures the optimizer.""" - log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") - return hydra.utils.instantiate( - self.optimizer_config, params=self.network.parameters() - ) - - def _configure_lr_scheduler( - self, optimizer: Type[torch.optim.Optimizer] - ) -> Dict[str, Any]: + optimizers = [] + for optimizer_config in self.optimizer_configs.values(): + network = getattr(self, optimizer_config.parameters) + del optimizer_config.parameters + log.info(f"Instantiating optimizer <{optimizer_config._target_}>") + optimizers.append( + hydra.utils.instantiate( + self.optimizer_config, params=network.parameters() + ) + ) + return optimizers + + def _configure_lr_schedulers( + self, optimizers: List[Type[torch.optim.Optimizer]] + ) -> List[Dict[str, Any]]: """Configures the lr scheduler.""" - # Extract non-class arguments. - monitor = self.lr_scheduler_config.monitor - interval = self.lr_scheduler_config.interval - del self.lr_scheduler_config.monitor - del self.lr_scheduler_config.interval - - log.info( - f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" - ) - scheduler = { - "monitor": monitor, - "interval": interval, - "scheduler": hydra.utils.instantiate( - self.lr_scheduler_config, optimizer=optimizer - ), - } - return scheduler - - def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: + schedulers = [] + for optimizer, lr_scheduler_config in zip( + optimizers, self.lr_scheduler_configs.values() + ): + # Extract non-class arguments. + monitor = lr_scheduler_config.monitor + interval = lr_scheduler_config.interval + del lr_scheduler_config.monitor + del lr_scheduler_config.interval + + log.info( + f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>" + ) + scheduler = { + "monitor": monitor, + "interval": interval, + "scheduler": hydra.utils.instantiate( + lr_scheduler_config, optimizer=optimizer + ), + } + schedulers.append(scheduler) + + return schedulers + + def configure_optimizers( + self, + ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" - optimizer = self._configure_optimizer() - scheduler = self._configure_lr_scheduler(optimizer) - return [optimizer], [scheduler] + optimizers = self._configure_optimizer() + schedulers = self._configure_lr_scheduler(optimizers) + return optimizers, schedulers def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 5fb84a7..75f7523 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -60,6 +60,8 @@ class TransformerLitModel(BaseLitModel): pred = self(data) self.val_cer(pred, targets) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(pred, targets) + self.log("val/acc", self.test_acc, on_step=False, on_epoch=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -69,6 +71,8 @@ class TransformerLitModel(BaseLitModel): pred = self(data) self.test_cer(pred, targets) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(pred, targets) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) def predict(self, x: Tensor) -> Tensor: """Predicts text in image. diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py new file mode 100644 index 0000000..8ff65cc --- /dev/null +++ b/text_recognizer/models/vqgan.py @@ -0,0 +1,135 @@ +"""PyTorch Lightning model for base Transformers.""" +from typing import Tuple + +import attr +from torch import Tensor + +from text_recognizer.models.base import BaseLitModel +from text_recognizer.criterions.vqgan_loss import VQGANLoss + + +@attr.s(auto_attribs=True, eq=False) +class VQVAELitModel(BaseLitModel): + """A PyTorch Lightning model for transformer networks.""" + + loss_fn: VQGANLoss = attr.ib() + latent_loss_weight: float = attr.ib(default=0.25) + + def forward(self, data: Tensor) -> Tensor: + """Forward pass with the transformer network.""" + return self.network(data) + + def training_step( + self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int + ) -> Tensor: + """Training step.""" + data, _ = batch + + reconstructions, vq_loss = self(data) + loss = self.loss_fn(reconstructions, data) + + if optimizer_idx == 0: + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=optimizer_idx, + stage="train", + ) + self.log( + "train/loss", + loss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return loss + + if optimizer_idx == 1: + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=optimizer_idx, + stage="train", + ) + self.log( + "train/discriminator_loss", + loss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, _ = batch + reconstructions, vq_loss = self(data) + + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=0, + stage="val", + ) + self.log( + "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + self.log( + "val/rec_loss", + log["val/rec_loss"], + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log) + + _, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=1, + stage="val", + ) + self.log_dict(log) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, _ = batch + reconstructions, vq_loss = self(data) + + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=0, + stage="test", + ) + self.log( + "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + self.log( + "test/rec_loss", + log["test/rec_loss"], + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log) + + _, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=1, + stage="test", + ) + self.log_dict(log) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index ef9a59a..56229b3 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -28,8 +28,8 @@ class VQVAELitModel(BaseLitModel): self.log("train/vq_loss", vq_loss) self.log("train/loss", loss) - self.train_acc(reconstructions, data) - self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) + # self.train_acc(reconstructions, data) + # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -42,8 +42,8 @@ class VQVAELitModel(BaseLitModel): self.log("val/vq_loss", vq_loss) self.log("val/loss", loss, prog_bar=True) - self.val_acc(reconstructions, data) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + # self.val_acc(reconstructions, data) + # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -53,5 +53,5 @@ class VQVAELitModel(BaseLitModel): loss = loss + self.latent_loss_weight * vq_loss self.log("test/vq_loss", vq_loss) self.log("test/loss", loss) - self.test_acc(reconstructions, data) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + # self.test_acc(reconstructions, data) + # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index f51e0a3..fcbed57 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -12,7 +12,14 @@ from text_recognizer.networks.vqvae.residual import Residual class Decoder(nn.Module): """A CNN encoder network.""" - def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None: + def __init__( + self, + out_channels: int, + hidden_dim: int, + channels_multipliers: Sequence[int], + dropout_rate: float, + activation: str = "mish", + ) -> None: super().__init__() self.out_channels = out_channels self.hidden_dim = hidden_dim @@ -33,9 +40,9 @@ class Decoder(nn.Module): use_norm=True, ), ] - + activation_fn = activation_function(self.activation) - out_channels_multipliers = self.channels_multipliers + (1, ) + out_channels_multipliers = self.channels_multipliers + (1,) num_blocks = len(self.channels_multipliers) for i in range(num_blocks): diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index ad8f950..4a5c976 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,7 +11,14 @@ from text_recognizer.networks.vqvae.residual import Residual class Encoder(nn.Module): """A CNN encoder network.""" - def __init__(self, in_channels: int, hidden_dim: int, channels_multipliers: List[int], dropout_rate: float, activation: str = "mish") -> None: + def __init__( + self, + in_channels: int, + hidden_dim: int, + channels_multipliers: List[int], + dropout_rate: float, + activation: str = "mish", + ) -> None: super().__init__() self.in_channels = in_channels self.hidden_dim = hidden_dim @@ -33,7 +40,7 @@ class Encoder(nn.Module): ] num_blocks = len(self.channels_multipliers) - channels_multipliers = (1, ) + self.channels_multipliers + channels_multipliers = (1,) + self.channels_multipliers activation_fn = activation_function(self.activation) for i in range(num_blocks): diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py index 3e6963a..d73f9f8 100644 --- a/text_recognizer/networks/vqvae/norm.py +++ b/text_recognizer/networks/vqvae/norm.py @@ -6,13 +6,17 @@ from torch import nn, Tensor @attr.s(eq=False) class Normalize(nn.Module): num_channels: int = attr.ib() + num_groups: int = attr.ib(default=32) norm: nn.GroupNorm = attr.ib(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" super().__init__() self.norm = nn.GroupNorm( - num_groups=self.num_channels, num_channels=self.num_channels, eps=1.0e-6, affine=True + num_groups=self.num_groups, + num_channels=self.num_channels, + eps=1.0e-6, + affine=True, ) def forward(self, x: Tensor) -> Tensor: diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py index 5c580df..b9e6080 100644 --- a/text_recognizer/networks/vqvae/pixelcnn.py +++ b/text_recognizer/networks/vqvae/pixelcnn.py @@ -44,7 +44,7 @@ class Encoder(nn.Module): ), ] num_blocks = len(self.channels_multipliers) - in_channels_multipliers = (1,) + self.channels_multipliers + in_channels_multipliers = (1,) + self.channels_multipliers for i in range(num_blocks): in_channels = self.hidden_dim * in_channels_multipliers[i] out_channels = self.hidden_dim * self.channels_multipliers[i] @@ -68,7 +68,7 @@ class Encoder(nn.Module): dropout_rate=self.dropout_rate, use_norm=True, ), - Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]) + Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]), ] encoder += [ @@ -125,7 +125,7 @@ class Decoder(nn.Module): ), ] - out_channels_multipliers = self.channels_multipliers + (1, ) + out_channels_multipliers = self.channels_multipliers + (1,) num_blocks = len(self.channels_multipliers) for i in range(num_blocks): @@ -140,11 +140,7 @@ class Decoder(nn.Module): ) ) if i == 0: - decoder.append( - Attention( - in_channels=out_channels - ) - ) + decoder.append(Attention(in_channels=out_channels)) decoder.append(Upsample()) decoder += [ diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py index 4ed3781..46b091d 100644 --- a/text_recognizer/networks/vqvae/residual.py +++ b/text_recognizer/networks/vqvae/residual.py @@ -18,7 +18,13 @@ class Residual(nn.Module): super().__init__() self.block = self._build_res_block() if self.in_channels != self.out_channels: - self.conv_shortcut = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=1) + self.conv_shortcut = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) else: self.conv_shortcut = None diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 0646119..e8660c4 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -32,7 +32,6 @@ class VQVAE(nn.Module): num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, ) - def encode(self, x: Tensor) -> Tensor: """Encodes input to a latent code.""" z_e = self.encoder(x) -- cgit v1.2.3-70-g09d2