| author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 | 
|---|---|---|
| committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 | 
| commit | 240f5e9f20032e82515fa66ce784619527d1041e (patch) | |
| tree | b002d28bbfc9abe9b6af090f7db60bea0aeed6e8 | |
| parent | d12f70402371dda586d457af2a3df7fb5b3130ad (diff) | |
Add VQGAN and loss function
31 files changed, 575 insertions, 113 deletions
@@ -30,6 +30,8 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw
 - [x] Efficient-net b0 + transformer decoder
 - [x] Load everything with hydra, get it to work
 - [x] Train network
+- [ ] Weight init
+- [ ] PatchGAN loss
 - [ ] Get VQVAE2 to work and not get loss NAN
 - [ ] Local attention for target sequence
 - [ ] Rotary embedding for target sequence
diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb
index b26a1fe..42621da 100644
--- a/notebooks/05c-test-model-end-to-end.ipynb
+++ b/notebooks/05c-test-model-end-to-end.ipynb
@@ -2,19 +2,10 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 4,
+   "execution_count": 1,
    "id": "1e40a88b",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "The autoreload extension is already loaded. To reload it, use:\n",
-      "  %reload_ext autoreload\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "%load_ext autoreload\n",
     "%autoreload 2\n",
@@ -34,6 +25,52 @@
   },
   {
    "cell_type": "code",
+   "execution_count": 2,
+   "id": "f40fc669-829c-4de8-83ed-475fc6a0b8c1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class T:\n",
+    "    def __init__(self):\n",
+    "        self.network = nn.Linear(1, 1)\n",
+    "        \n",
+    "    def get(self):\n",
+    "        return getattr(self, \"network\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "d2bedf96-5388-4c7a-a048-1b97041cbedc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "t = T()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "a6fbe3be-2a9f-4050-a397-7ad982d6cd05",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "<generator object Module.parameters at 0x7f29ad6d6120>"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "t.get().parameters()"
+   ]
+  },
+  {
+   "cell_type": "code",
    "execution_count": 5,
    "id": "38fb3d9d-a163-4b72-981f-f31b51be39f2",
    "metadata": {},
diff --git a/text_recognizer/criterions/n_layer_discriminator.py b/text_recognizer/criterions/n_layer_discriminator.py
new file mode 100644
index 0000000..e5f8449
--- /dev/null
+++ b/text_recognizer/criterions/n_layer_discriminator.py
@@ -0,0 +1,58 @@
+"""PatchGAN discriminator from Pix2Pix."""
+from torch import nn, Tensor
+
+from text_recognizer.networks.vqvae.norm import Normalize
+
+
+class NLayerDiscriminator(nn.Module):
+    """Defines the PatchGAN discriminator used in Pix2Pix."""
+
+    def __init__(
+        self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3
+    ) -> None:
+        super().__init__()
+        self.in_channels = in_channels
+        self.num_channels = num_channels
+        self.num_layers = num_layers
+        self.discriminator = self._build_discriminator()
+
+    def _build_discriminator(self) -> nn.Sequential:
+        """Builds discriminator."""
+        discriminator = [
+            nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.num_channels,
+                kernel_size=4,
+                stride=2,
+                padding=1,
+            ),
+            nn.Mish(inplace=True),
+        ]
+        in_channels = self.num_channels
+        for n in range(1, self.num_layers):
+            discriminator += [
+                nn.Conv2d(
+                    in_channels=in_channels,
+                    out_channels=in_channels * n,
+                    kernel_size=4,
+                    stride=2,
+                    padding=1,
+                ),
+                Normalize(num_channels=in_channels * n),
+                nn.Mish(inplace=True),
+            ]
+            in_channels *= n
+
+        discriminator += [
+            nn.Conv2d(
+                in_channels=self.num_channels * (self.num_layers - 1),
+                out_channels=1,
+                kernel_size=4,
+                padding=1,
+            )
+        ]
+        return nn.Sequential(*discriminator)
+
+    def forward(self, x: Tensor) -> Tensor:
+        """Forward pass through discriminator."""
+        return self.discriminator(x)
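For context (not part of the commit): the PatchGAN discriminator above outputs a grid of per-patch logits rather than a single real/fake scalar. A minimal sketch of the output shape under the default arguments, assuming the 1×576×640 input size noted in config.yaml:

```python
import torch

from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator

# Three stride-2, kernel-4, padding-1 convolutions halve each spatial side;
# the final stride-1, kernel-4, padding-1 convolution shrinks each side by one.
# Every element of the output is a real/fake logit for one image patch.
discriminator = NLayerDiscriminator(in_channels=1, num_channels=32, num_layers=3)
patch_logits = discriminator(torch.randn(8, 1, 576, 640))
print(patch_logits.shape)  # torch.Size([8, 1, 71, 79])
```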
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py
new file mode 100644
index 0000000..8bb568f
--- /dev/null
+++ b/text_recognizer/criterions/vqgan_loss.py
@@ -0,0 +1,84 @@
+"""VQGAN loss for PyTorch Lightning."""
+from typing import Dict, Tuple
+
+import torch
+from torch import nn, Tensor
+import torch.nn.functional as F
+
+from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator
+
+
+class VQGANLoss(nn.Module):
+    """VQGAN loss."""
+
+    def __init__(
+        self,
+        reconstruction_loss: nn.L1Loss,
+        discriminator: NLayerDiscriminator,
+        vq_loss_weight: float = 1.0,
+        discriminator_weight: float = 1.0,
+    ) -> None:
+        super().__init__()
+        self.reconstruction_loss = reconstruction_loss
+        self.discriminator = discriminator
+        self.vq_loss_weight = vq_loss_weight
+        self.discriminator_weight = discriminator_weight
+
+    @staticmethod
+    def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor:
+        """Calculates the adversarial loss."""
+        loss_real = torch.mean(F.relu(1.0 - logits_real))
+        loss_fake = torch.mean(F.relu(1.0 + logits_fake))
+        d_loss = (loss_real + loss_fake) / 2.0
+        return d_loss
+
+    def forward(
+        self,
+        data: Tensor,
+        reconstructions: Tensor,
+        vq_loss: Tensor,
+        optimizer_idx: int,
+        stage: str,
+    ) -> Tuple[Tensor, Dict[str, Tensor]]:
+        """Calculates the VQGAN loss."""
+        rec_loss = self.reconstruction_loss(
+            data.contiguous(), reconstructions.contiguous()
+        )
+
+        # GAN part.
+        if optimizer_idx == 0:
+            logits_fake = self.discriminator(reconstructions.contiguous())
+            g_loss = -torch.mean(logits_fake)
+
+            loss = (
+                rec_loss
+                + self.discriminator_weight * g_loss
+                + self.vq_loss_weight * vq_loss
+            )
+            log = {
+                f"{stage}/loss": loss,
+                f"{stage}/vq_loss": vq_loss,
+                f"{stage}/rec_loss": rec_loss,
+                f"{stage}/g_loss": g_loss,
+            }
+            return loss, log
+
+        if optimizer_idx == 1:
+            logits_fake = self.discriminator(reconstructions.contiguous().detach())
+            logits_real = self.discriminator(data.contiguous().detach())
+
+            d_loss = self.adversarial_loss(
+                logits_real=logits_real, logits_fake=logits_fake
+            )
+            loss = (
+                rec_loss
+                + self.discriminator_weight * d_loss
+                + self.vq_loss_weight * vq_loss
+            )
+            log = {
+                f"{stage}/loss": loss,
+                f"{stage}/vq_loss": vq_loss,
+                f"{stage}/rec_loss": rec_loss,
+                f"{stage}/d_loss": d_loss,
+            }
+            return loss, log
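Usage sketch (not from the commit): the loss module is called once per optimizer each step, first the generator pass (`optimizer_idx=0`, gradients flow through the reconstructions) and then the discriminator pass (`optimizer_idx=1`, both inputs detached before the hinge loss):

```python
import torch
from torch import nn

from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator
from text_recognizer.criterions.vqgan_loss import VQGANLoss

loss_fn = VQGANLoss(
    reconstruction_loss=nn.L1Loss(reduction="mean"),
    discriminator=NLayerDiscriminator(in_channels=1, num_channels=32, num_layers=3),
    vq_loss_weight=1.0,
    discriminator_weight=1.0,
)

data = torch.randn(4, 1, 64, 64)
reconstructions = torch.randn(4, 1, 64, 64, requires_grad=True)  # stand-in for VQVAE output
vq_loss = torch.tensor(0.1)

# Generator pass: -mean(D(recon)) pushes reconstructions towards "real".
g_loss, g_log = loss_fn(data, reconstructions, vq_loss, optimizer_idx=0, stage="train")
# Discriminator pass: hinge loss on real logits vs. detached fake logits.
d_loss, d_log = loss_fn(data, reconstructions, vq_loss, optimizer_idx=1, stage="train")
```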
diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py
index 3e91594..4406db7 100644
--- a/text_recognizer/data/emnist_mapping.py
+++ b/text_recognizer/data/emnist_mapping.py
@@ -9,7 +9,9 @@ from text_recognizer.data.emnist import emnist_mapping
 class EmnistMapping(AbstractMapping):
-    def __init__(self, extra_symbols: Optional[Set[str]] = None, lower: bool = True) -> None:
+    def __init__(
+        self, extra_symbols: Optional[Set[str]] = None, lower: bool = True
+    ) -> None:
         self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None
         self.mapping, self.inverse_mapping, self.input_size = emnist_mapping(
             self.extra_symbols
@@ -20,10 +22,12 @@ class EmnistMapping(AbstractMapping):
     def _to_lower(self) -> None:
         """Converts mapping to lowercase letters only."""
+
         def _filter(x: int) -> int:
             if 40 <= x:
                 return x - 26
             return x
+
         self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)}
         self.mapping = [c for c in self.mapping if not c.isupper()]
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index ab3fa35..8b68ed9 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -24,8 +24,8 @@ class BaseLitModel(LightningModule):
     network: Type[nn.Module] = attr.ib()
     mapping: Type[AbstractMapping] = attr.ib()
     loss_fn: Type[nn.Module] = attr.ib()
-    optimizer_config: DictConfig = attr.ib()
-    lr_scheduler_config: DictConfig = attr.ib()
+    optimizer_configs: DictConfig = attr.ib()
+    lr_scheduler_configs: DictConfig = attr.ib()
     train_acc: torchmetrics.Accuracy = attr.ib(
         init=False, default=torchmetrics.Accuracy()
     )
@@ -45,40 +45,55 @@ class BaseLitModel(LightningModule):
     ) -> None:
         optimizer.zero_grad(set_to_none=True)
 
-    def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
+    def _configure_optimizers(self) -> List[Type[torch.optim.Optimizer]]:
         """Configures the optimizer."""
-        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
-        return hydra.utils.instantiate(
-            self.optimizer_config, params=self.network.parameters()
-        )
-
-    def _configure_lr_scheduler(
-        self, optimizer: Type[torch.optim.Optimizer]
-    ) -> Dict[str, Any]:
+        optimizers = []
+        for optimizer_config in self.optimizer_configs.values():
+            network = getattr(self, optimizer_config.parameters)
+            del optimizer_config.parameters
+            log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
+            optimizers.append(
+                hydra.utils.instantiate(
+                    optimizer_config, params=network.parameters()
+                )
+            )
+        return optimizers
+
+    def _configure_lr_schedulers(
+        self, optimizers: List[Type[torch.optim.Optimizer]]
+    ) -> List[Dict[str, Any]]:
         """Configures the lr scheduler."""
-        # Extract non-class arguments.
-        monitor = self.lr_scheduler_config.monitor
-        interval = self.lr_scheduler_config.interval
-        del self.lr_scheduler_config.monitor
-        del self.lr_scheduler_config.interval
-
-        log.info(
-            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
-        )
-        scheduler = {
-            "monitor": monitor,
-            "interval": interval,
-            "scheduler": hydra.utils.instantiate(
-                self.lr_scheduler_config, optimizer=optimizer
-            ),
-        }
-        return scheduler
-
-    def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+        schedulers = []
+        for optimizer, lr_scheduler_config in zip(
+            optimizers, self.lr_scheduler_configs.values()
+        ):
+            # Extract non-class arguments.
+            monitor = lr_scheduler_config.monitor
+            interval = lr_scheduler_config.interval
+            del lr_scheduler_config.monitor
+            del lr_scheduler_config.interval
+
+            log.info(
+                f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
+            )
+            scheduler = {
+                "monitor": monitor,
+                "interval": interval,
+                "scheduler": hydra.utils.instantiate(
+                    lr_scheduler_config, optimizer=optimizer
+                ),
+            }
+            schedulers.append(scheduler)
+
+        return schedulers
+
+    def configure_optimizers(
+        self,
+    ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
         """Configures optimizer and lr scheduler."""
-        optimizer = self._configure_optimizer()
-        scheduler = self._configure_lr_scheduler(optimizer)
-        return [optimizer], [scheduler]
+        optimizers = self._configure_optimizers()
+        schedulers = self._configure_lr_schedulers(optimizers)
+        return optimizers, schedulers
 
     def forward(self, data: Tensor) -> Tensor:
         """Feedforward pass."""
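The loop above assumes `optimizer_configs` is a name-keyed mapping whose `parameters` entry names an attribute on the model ("network" for the generator, "loss_fn" for the module holding the discriminator). A hypothetical config shaped the way `_configure_optimizers` consumes it:

```python
from omegaconf import OmegaConf

# Hypothetical values for illustration; the repo builds this from YAML groups.
optimizer_configs = OmegaConf.create(
    {
        "generator": {
            "_target_": "madgrad.MADGRAD",
            "lr": 1.0e-3,
            "parameters": "network",  # getattr(self, "network").parameters()
        },
        "discriminator": {
            "_target_": "madgrad.MADGRAD",
            "lr": 1.0e-3,
            "parameters": "loss_fn",  # getattr(self, "loss_fn").parameters()
        },
    }
)
```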
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 5fb84a7..75f7523 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -60,6 +60,8 @@ class TransformerLitModel(BaseLitModel):
         pred = self(data)
         self.val_cer(pred, targets)
         self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.val_acc(pred, targets)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -69,6 +71,8 @@ class TransformerLitModel(BaseLitModel):
         pred = self(data)
         self.test_cer(pred, targets)
         self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.test_acc(pred, targets)
+        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
 
     def predict(self, x: Tensor) -> Tensor:
         """Predicts text in image.
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
new file mode 100644
index 0000000..8ff65cc
--- /dev/null
+++ b/text_recognizer/models/vqgan.py
@@ -0,0 +1,134 @@
+"""PyTorch Lightning model for the VQGAN."""
+from typing import Tuple
+
+import attr
+from torch import Tensor
+
+from text_recognizer.models.base import BaseLitModel
+from text_recognizer.criterions.vqgan_loss import VQGANLoss
+
+
+@attr.s(auto_attribs=True, eq=False)
+class VQGANLitModel(BaseLitModel):
+    """A PyTorch Lightning model for the VQGAN."""
+
+    loss_fn: VQGANLoss = attr.ib()
+    latent_loss_weight: float = attr.ib(default=0.25)
+
+    def forward(self, data: Tensor) -> Tensor:
+        """Forward pass with the VQVAE network."""
+        return self.network(data)
+
+    def training_step(
+        self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
+    ) -> Tensor:
+        """Training step."""
+        data, _ = batch
+
+        reconstructions, vq_loss = self(data)
+
+        if optimizer_idx == 0:
+            loss, log = self.loss_fn(
+                data=data,
+                reconstructions=reconstructions,
+                vq_loss=vq_loss,
+                optimizer_idx=optimizer_idx,
+                stage="train",
+            )
+            self.log(
+                "train/loss",
+                loss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return loss
+
+        if optimizer_idx == 1:
+            loss, log = self.loss_fn(
+                data=data,
+                reconstructions=reconstructions,
+                vq_loss=vq_loss,
+                optimizer_idx=optimizer_idx,
+                stage="train",
+            )
+            self.log(
+                "train/discriminator_loss",
+                loss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return loss
+
+    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step."""
+        data, _ = batch
+        reconstructions, vq_loss = self(data)
+
+        loss, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=0,
+            stage="val",
+        )
+        self.log(
+            "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+        )
+        self.log(
+            "val/rec_loss",
+            log["val/rec_loss"],
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=True,
+        )
+        self.log_dict(log)
+
+        _, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=1,
+            stage="val",
+        )
+        self.log_dict(log)
+
+    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step."""
+        data, _ = batch
+        reconstructions, vq_loss = self(data)
+
+        loss, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=0,
+            stage="test",
+        )
+        self.log(
+            "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+        )
+        self.log(
+            "test/rec_loss",
+            log["test/rec_loss"],
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=True,
+        )
+        self.log_dict(log)
+
+        _, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=1,
+            stage="test",
+        )
+        self.log_dict(log)
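For context (a self-contained sketch, not the repo's Hydra-driven `BaseLitModel.configure_optimizers`): when `configure_optimizers` returns two optimizers, Lightning calls `training_step(batch, batch_idx, optimizer_idx)` once per optimizer for every batch, so index 0 drives the generator objective and index 1 the discriminator objective:

```python
import torch

# AdamW stands in for the repo's MADGRAD; the list ordering defines optimizer_idx.
def configure_optimizers(self):
    generator_opt = torch.optim.AdamW(self.network.parameters(), lr=1.0e-3)
    discriminator_opt = torch.optim.AdamW(
        self.loss_fn.discriminator.parameters(), lr=1.0e-3
    )
    return [generator_opt, discriminator_opt]  # optimizer_idx 0 and 1
```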
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index ef9a59a..56229b3 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -28,8 +28,8 @@ class VQVAELitModel(BaseLitModel):
         self.log("train/vq_loss", vq_loss)
         self.log("train/loss", loss)
 
-        self.train_acc(reconstructions, data)
-        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
+        # self.train_acc(reconstructions, data)
+        # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
         return loss
 
     def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -42,8 +42,8 @@ class VQVAELitModel(BaseLitModel):
         self.log("val/vq_loss", vq_loss)
         self.log("val/loss", loss, prog_bar=True)
 
-        self.val_acc(reconstructions, data)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+        # self.val_acc(reconstructions, data)
+        # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -53,5 +53,5 @@ class VQVAELitModel(BaseLitModel):
         loss = loss + self.latent_loss_weight * vq_loss
         self.log("test/vq_loss", vq_loss)
         self.log("test/loss", loss)
-        self.test_acc(reconstructions, data)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        # self.test_acc(reconstructions, data)
+        # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py
index f51e0a3..fcbed57 100644
--- a/text_recognizer/networks/vqvae/decoder.py
+++ b/text_recognizer/networks/vqvae/decoder.py
@@ -12,7 +12,14 @@ from text_recognizer.networks.vqvae.residual import Residual
 class Decoder(nn.Module):
     """A CNN encoder network."""
 
-    def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None:
+    def __init__(
+        self,
+        out_channels: int,
+        hidden_dim: int,
+        channels_multipliers: Sequence[int],
+        dropout_rate: float,
+        activation: str = "mish",
+    ) -> None:
         super().__init__()
         self.out_channels = out_channels
         self.hidden_dim = hidden_dim
@@ -33,9 +40,9 @@ class Decoder(nn.Module):
                     use_norm=True,
                 ),
             ]
-        
+
         activation_fn = activation_function(self.activation)
-        out_channels_multipliers = self.channels_multipliers + (1, )
+        out_channels_multipliers = self.channels_multipliers + (1,)
         num_blocks = len(self.channels_multipliers)
 
         for i in range(num_blocks):
diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py
index ad8f950..4a5c976 100644
--- a/text_recognizer/networks/vqvae/encoder.py
+++ b/text_recognizer/networks/vqvae/encoder.py
@@ -11,7 +11,14 @@ from text_recognizer.networks.vqvae.residual import Residual
 class Encoder(nn.Module):
     """A CNN encoder network."""
 
-    def __init__(self, in_channels: int, hidden_dim: int, channels_multipliers: List[int], dropout_rate: float, activation: str = "mish") -> None:
+    def __init__(
+        self,
+        in_channels: int,
+        hidden_dim: int,
+        channels_multipliers: List[int],
+        dropout_rate: float,
+        activation: str = "mish",
+    ) -> None:
         super().__init__()
         self.in_channels = in_channels
         self.hidden_dim = hidden_dim
@@ -33,7 +40,7 @@ class Encoder(nn.Module):
         ]
 
         num_blocks = len(self.channels_multipliers)
-        channels_multipliers = (1, ) + self.channels_multipliers
+        channels_multipliers = (1,) + self.channels_multipliers
         activation_fn = activation_function(self.activation)
 
         for i in range(num_blocks):
diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py
index 3e6963a..d73f9f8 100644
--- a/text_recognizer/networks/vqvae/norm.py
+++ b/text_recognizer/networks/vqvae/norm.py
@@ -6,13 +6,17 @@ from torch import nn, Tensor
 @attr.s(eq=False)
 class Normalize(nn.Module):
     num_channels: int = attr.ib()
+    num_groups: int = attr.ib(default=32)
     norm: nn.GroupNorm = attr.ib(init=False)
 
     def __attrs_post_init__(self) -> None:
         """Post init configuration."""
         super().__init__()
         self.norm = nn.GroupNorm(
-            num_groups=self.num_channels, num_channels=self.num_channels, eps=1.0e-6, affine=True
+            num_groups=self.num_groups,
+            num_channels=self.num_channels,
+            eps=1.0e-6,
+            affine=True,
         )
 
     def forward(self, x: Tensor) -> Tensor:
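The Normalize change above decouples the group count from the channel count: previously `num_groups=num_channels` gave one group per channel (InstanceNorm-like statistics), whereas the new default of 32 groups shares statistics across channels. A quick illustration (mine, not from the repo):

```python
import torch
from torch import nn

x = torch.randn(2, 64, 8, 8)

# Old behaviour: one group per channel.
per_channel = nn.GroupNorm(num_groups=64, num_channels=64, eps=1.0e-6, affine=True)
# New behaviour: 32 groups, so statistics are shared across pairs of channels here.
grouped = nn.GroupNorm(num_groups=32, num_channels=64, eps=1.0e-6, affine=True)

assert per_channel(x).shape == grouped(x).shape == x.shape
# Caveat: num_channels must be divisible by num_groups, so the default of 32
# groups assumes channel counts that are multiples of 32.
```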
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
index 5c580df..b9e6080 100644
--- a/text_recognizer/networks/vqvae/pixelcnn.py
+++ b/text_recognizer/networks/vqvae/pixelcnn.py
@@ -44,7 +44,7 @@ class Encoder(nn.Module):
             ),
         ]
         num_blocks = len(self.channels_multipliers)
-        in_channels_multipliers = (1,) + self.channels_multipliers 
+        in_channels_multipliers = (1,) + self.channels_multipliers
         for i in range(num_blocks):
             in_channels = self.hidden_dim * in_channels_multipliers[i]
             out_channels = self.hidden_dim * self.channels_multipliers[i]
@@ -68,7 +68,7 @@ class Encoder(nn.Module):
                     dropout_rate=self.dropout_rate,
                     use_norm=True,
                 ),
-                Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1])
+                Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]),
             ]
 
         encoder += [
@@ -125,7 +125,7 @@ class Decoder(nn.Module):
             ),
         ]
 
-        out_channels_multipliers = self.channels_multipliers + (1, )
+        out_channels_multipliers = self.channels_multipliers + (1,)
         num_blocks = len(self.channels_multipliers)
 
         for i in range(num_blocks):
@@ -140,11 +140,7 @@ class Decoder(nn.Module):
                 )
             )
             if i == 0:
-                decoder.append(
-                    Attention(
-                        in_channels=out_channels
-                    )
-                )
+                decoder.append(Attention(in_channels=out_channels))
             decoder.append(Upsample())
 
         decoder += [
diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py
index 4ed3781..46b091d 100644
--- a/text_recognizer/networks/vqvae/residual.py
+++ b/text_recognizer/networks/vqvae/residual.py
@@ -18,7 +18,13 @@ class Residual(nn.Module):
         super().__init__()
         self.block = self._build_res_block()
         if self.in_channels != self.out_channels:
-            self.conv_shortcut = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=1)
+            self.conv_shortcut = nn.Conv2d(
+                in_channels=self.in_channels,
+                out_channels=self.out_channels,
+                kernel_size=3,
+                stride=1,
+                padding=1,
+            )
         else:
             self.conv_shortcut = None
 
diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py
index 0646119..e8660c4 100644
--- a/text_recognizer/networks/vqvae/vqvae.py
+++ b/text_recognizer/networks/vqvae/vqvae.py
@@ -32,7 +32,6 @@ class VQVAE(nn.Module):
             num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay,
         )
 
-
     def encode(self, x: Tensor) -> Tensor:
         """Encodes input to a latent code."""
         z_e = self.encoder(x)
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index c606366..5897d87 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -6,11 +6,13 @@ defaults:
   - datamodule: iam_extended_paragraphs
   - hydra: default
   - logger: wandb
-  - lr_scheduler: one_cycle
+  - lr_schedulers:
+      - one_cycle
   - mapping: word_piece
   - model: lit_transformer
   - network: conv_transformer
-  - optimizer: madgrad
+  - optimizers:
+      - madgrad
   - trainer: default
 
 seed: 4711
@@ -32,7 +34,9 @@ work_dir: ${hydra:runtime.cwd}
 debug: False
 
 # pretty print config at the start of the run using Rich library
-print_config: True
+print_config: false
 
 # disable python warnings if they annoy you
-ignore_warnings: True
+ignore_warnings: true
+
+summary: null  # [1, 576, 640]
diff --git a/training/conf/criterion/mae.yaml b/training/conf/criterion/mae.yaml
new file mode 100644
index 0000000..cb07467
--- /dev/null
+++ b/training/conf/criterion/mae.yaml
@@ -0,0 +1,2 @@
+_target_: torch.nn.L1Loss
+reduction: mean
diff --git a/training/conf/criterion/vqgan_loss.yaml b/training/conf/criterion/vqgan_loss.yaml
new file mode 100644
index 0000000..a1c886e
--- /dev/null
+++ b/training/conf/criterion/vqgan_loss.yaml
@@ -0,0 +1,12 @@
+_target_: text_recognizer.criterions.vqgan_loss.VQGANLoss
+reconstruction_loss:
+  _target_: torch.nn.L1Loss
+  reduction: mean
+discriminator:
+  _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator
+  in_channels: 1
+  num_channels: 32
+  num_layers: 3
+vq_loss_weight: 1.0
+discriminator_weight: 1.0
+
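Usage sketch for the criterion config above (an inline copy of `training/conf/criterion/vqgan_loss.yaml` built by hand; Hydra instantiates the nested `reconstruction_loss` and `discriminator` recursively):

```python
import hydra.utils
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "text_recognizer.criterions.vqgan_loss.VQGANLoss",
        "reconstruction_loss": {"_target_": "torch.nn.L1Loss", "reduction": "mean"},
        "discriminator": {
            "_target_": "text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator",
            "in_channels": 1,
            "num_channels": 32,
            "num_layers": 3,
        },
        "vq_loss_weight": 1.0,
        "discriminator_weight": 1.0,
    }
)
loss_fn = hydra.utils.instantiate(cfg)  # builds VQGANLoss with its sub-modules
```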
diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml
new file mode 100644
index 0000000..3d97892
--- /dev/null
+++ b/training/conf/experiment/vqgan.yaml
@@ -0,0 +1,55 @@
+# @package _global_
+
+defaults:
+  - override /network: vqvae
+  - override /criterion: vqgan_loss
+  - override /model: lit_vqgan
+  - override /callbacks: wandb_vae
+  - override /lr_schedulers: null
+
+datamodule:
+  batch_size: 8
+
+lr_schedulers:
+  generator:
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: 256
+    eta_min: 0.0
+    last_epoch: -1
+
+    interval: epoch
+    monitor: val/loss
+
+  discriminator:
+    _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+    T_max: 256
+    eta_min: 0.0
+    last_epoch: -1
+
+    interval: epoch
+    monitor: val/loss
+
+optimizers:
+  generator:
+    _target_: madgrad.MADGRAD
+    lr: 1.0e-3
+    momentum: 0.9
+    weight_decay: 0
+    eps: 1.0e-6
+
+    parameters: network
+
+  discriminator:
+    _target_: madgrad.MADGRAD
+    lr: 1.0e-3
+    momentum: 0.9
+    weight_decay: 0
+    eps: 1.0e-6
+
+    parameters: loss_fn
+
+trainer:
+  max_epochs: 256
+  # gradient_clip_val: 0.25
+
+summary: null
diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml
index 7a9e643..397a039 100644
--- a/training/conf/experiment/vqvae.yaml
+++ b/training/conf/experiment/vqvae.yaml
@@ -2,17 +2,18 @@
 defaults:
   - override /network: vqvae
-  - override /criterion: mse
+  - override /criterion: mae
   - override /model: lit_vqvae
   - override /callbacks: wandb_vae
-  - override /lr_scheduler: cosine_annealing
+  - override /lr_schedulers:
+      - cosine_annealing
 
 trainer:
-  max_epochs: 64
+  max_epochs: 256
   # gradient_clip_val: 0.25
 
 datamodule:
-  batch_size: 16
+  batch_size: 8
 
 # lr_scheduler:
   # epochs: 64
@@ -21,4 +22,4 @@ datamodule:
 # optimizer:
   # lr: 1.0e-3
 
-summary: [1, 576, 640]
+summary: null
diff --git a/training/conf/experiment/vqvae_pixelcnn.yaml b/training/conf/experiment/vqvae_pixelcnn.yaml
new file mode 100644
index 0000000..4fae782
--- /dev/null
+++ b/training/conf/experiment/vqvae_pixelcnn.yaml
@@ -0,0 +1,24 @@
+# @package _global_
+
+defaults:
+  - override /network: vqvae_pixelcnn
+  - override /criterion: mae
+  - override /model: lit_vqvae
+  - override /callbacks: wandb_vae
+  - override /lr_schedulers:
+      - cosine_annealing
+
+trainer:
+  max_epochs: 256
+  # gradient_clip_val: 0.25
+
+datamodule:
+  batch_size: 8
+
+# lr_scheduler:
+  # epochs: 64
+  # steps_per_epoch: 1245
+
+# optimizer:
+  # lr: 1.0e-3
+
diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_scheduler/cosine_annealing.yaml
index 62667bb..c53ee3a 100644
--- a/training/conf/lr_scheduler/cosine_annealing.yaml
+++ b/training/conf/lr_scheduler/cosine_annealing.yaml
@@ -1,7 +1,8 @@
-_target_: torch.optim.lr_scheduler.CosineAnnealingLR
-T_max: 64
-eta_min: 0.0
-last_epoch: -1
+cosine_annealing:
+  _target_: torch.optim.lr_scheduler.CosineAnnealingLR
+  T_max: 256
+  eta_min: 0.0
+  last_epoch: -1
 
-interval: epoch
-monitor: val/loss
+  interval: epoch
+  monitor: val/loss
diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml
index fb5987a..c60577a 100644
--- a/training/conf/lr_scheduler/one_cycle.yaml
+++ b/training/conf/lr_scheduler/one_cycle.yaml
@@ -1,19 +1,20 @@
-_target_: torch.optim.lr_scheduler.OneCycleLR
-max_lr: 1.0e-3
-total_steps: null
-epochs: 512
-steps_per_epoch: 4992
-pct_start: 0.3
-anneal_strategy: cos
-cycle_momentum: true
-base_momentum: 0.85
-max_momentum: 0.95
-div_factor: 25.0
-final_div_factor: 10000.0
-three_phase: true
-last_epoch: -1
-verbose: false
+one_cycle:
+  _target_: torch.optim.lr_scheduler.OneCycleLR
+  max_lr: 1.0e-3
+  total_steps: null
+  epochs: 512
+  steps_per_epoch: 4992
+  pct_start: 0.3
+  anneal_strategy: cos
+  cycle_momentum: true
+  base_momentum: 0.85
+  max_momentum: 0.95
+  div_factor: 25.0
+  final_div_factor: 10000.0
+  three_phase: true
+  last_epoch: -1
+  verbose: false
 
-# Non-class arguments
-interval: step
-monitor: val/loss
+  # Non-class arguments
+  interval: step
+  monitor: val/loss
diff --git a/training/conf/network/encoder/pixelcnn_decoder.yaml b/training/conf/network/decoder/pixelcnn_decoder.yaml
index 3895164..cdddb7a 100644
--- a/training/conf/network/encoder/pixelcnn_decoder.yaml
+++ b/training/conf/network/decoder/pixelcnn_decoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.pixelcnn.Decoder
 out_channels: 1
 hidden_dim: 8
-channels_multipliers: [8, 8, 2, 1]
+channels_multipliers: [8, 2, 1]
 dropout_rate: 0.25
diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml
index 0a36a54..a5e7286 100644
--- a/training/conf/network/decoder/vae_decoder.yaml
+++ b/training/conf/network/decoder/vae_decoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.decoder.Decoder
 out_channels: 1
 hidden_dim: 32
-channels_multipliers: [4, 4, 2, 1]
+channels_multipliers: [8, 8, 4, 1]
 dropout_rate: 0.25
diff --git a/training/conf/network/decoder/pixelcnn_encoder.yaml b/training/conf/network/encoder/pixelcnn_encoder.yaml
index 47a130d..f12957b 100644
--- a/training/conf/network/decoder/pixelcnn_encoder.yaml
+++ b/training/conf/network/encoder/pixelcnn_encoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.pixelcnn.Encoder
 in_channels: 1
 hidden_dim: 8
-channels_multipliers: [1, 2, 8, 8]
+channels_multipliers: [1, 2, 8]
 dropout_rate: 0.25
diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml
index dacd389..58e905d 100644
--- a/training/conf/network/encoder/vae_encoder.yaml
+++ b/training/conf/network/encoder/vae_encoder.yaml
@@ -1,5 +1,5 @@
 _target_: text_recognizer.networks.vqvae.encoder.Encoder
 in_channels: 1
 hidden_dim: 32
-channels_multipliers: [1, 2, 4, 4]
+channels_multipliers: [1, 2, 4, 8, 8]
 dropout_rate: 0.25
diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml
index d97e9b6..835d0b7 100644
--- a/training/conf/network/vqvae.yaml
+++ b/training/conf/network/vqvae.yaml
@@ -3,7 +3,7 @@ defaults:
   - decoder: vae_decoder
 
 _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 128
+hidden_dim: 256
 embedding_dim: 32
 num_embeddings: 1024
 decay: 0.99
diff --git a/training/conf/network/vqvae_pixelcnn.yaml b/training/conf/network/vqvae_pixelcnn.yaml
index 10200bc..cd850af 100644
--- a/training/conf/network/vqvae_pixelcnn.yaml
+++ b/training/conf/network/vqvae_pixelcnn.yaml
@@ -5,5 +5,5 @@ defaults:
 _target_: text_recognizer.networks.vqvae.vqvae.VQVAE
 hidden_dim: 64
 embedding_dim: 32
-num_embeddings: 512
+num_embeddings: 1024
 decay: 0.99
diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml
index 458b116..a6c059d 100644
--- a/training/conf/optimizer/madgrad.yaml
+++ b/training/conf/optimizer/madgrad.yaml
@@ -1,5 +1,8 @@
-_target_: madgrad.MADGRAD
-lr: 3.0e-4
-momentum: 0.9
-weight_decay: 0
-eps: 1.0e-6
+madgrad:
+  _target_: madgrad.MADGRAD
+  lr: 1.0e-3
+  momentum: 0.9
+  weight_decay: 0
+  eps: 1.0e-6
+
+  parameters: network
diff --git a/training/run.py b/training/run.py
index a2529b0..0cf52e3 100644
--- a/training/run.py
+++ b/training/run.py
@@ -50,8 +50,8 @@ def run(config: DictConfig) -> Optional[float]:
         mapping=mapping,
         network=network,
         loss_fn=loss_fn,
-        optimizer_config=config.optimizer,
-        lr_scheduler_config=config.lr_scheduler,
+        optimizer_configs=config.optimizers,
+        lr_scheduler_configs=config.lr_schedulers,
         _recursive_=False,
     )