diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-08-08 19:59:55 +0200 |
commit | 240f5e9f20032e82515fa66ce784619527d1041e (patch) | |
tree | b002d28bbfc9abe9b6af090f7db60bea0aeed6e8 | |
parent | d12f70402371dda586d457af2a3df7fb5b3130ad (diff) |
Add VQGAN and loss function
31 files changed, 575 insertions, 113 deletions
@@ -30,6 +30,8 @@ python build-transitions --tokens iamdb_1kwp_tokens_1000.txt --lexicon iamdb_1kw - [x] Efficient-net b0 + transformer decoder - [x] Load everything with hydra, get it to work - [x] Train network +- [ ] Weight init +- [ ] patchgan loss - [ ] Get VQVAE2 to work and not get loss NAN - [ ] Local attention for target sequence - [ ] Rotary embedding for target sequence diff --git a/notebooks/05c-test-model-end-to-end.ipynb b/notebooks/05c-test-model-end-to-end.ipynb index b26a1fe..42621da 100644 --- a/notebooks/05c-test-model-end-to-end.ipynb +++ b/notebooks/05c-test-model-end-to-end.ipynb @@ -2,19 +2,10 @@ "cells": [ { "cell_type": "code", - "execution_count": 4, + "execution_count": 1, "id": "1e40a88b", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "%autoreload 2\n", @@ -34,6 +25,52 @@ }, { "cell_type": "code", + "execution_count": 2, + "id": "f40fc669-829c-4de8-83ed-475fc6a0b8c1", + "metadata": {}, + "outputs": [], + "source": [ + "class T:\n", + " def __init__(self):\n", + " self.network = nn.Linear(1, 1)\n", + " \n", + " def get(self):\n", + " return getattr(self, \"network\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d2bedf96-5388-4c7a-a048-1b97041cbedc", + "metadata": {}, + "outputs": [], + "source": [ + "t = T()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a6fbe3be-2a9f-4050-a397-7ad982d6cd05", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "<generator object Module.parameters at 0x7f29ad6d6120>" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "t.get().parameters()" + ] + }, + { + "cell_type": "code", "execution_count": 5, "id": "38fb3d9d-a163-4b72-981f-f31b51be39f2", "metadata": {}, diff --git a/text_recognizer/criterions/n_layer_discriminator.py b/text_recognizer/criterions/n_layer_discriminator.py new file mode 100644 index 0000000..e5f8449 --- /dev/null +++ b/text_recognizer/criterions/n_layer_discriminator.py @@ -0,0 +1,58 @@ +"""Pix2pix discriminator loss.""" +from torch import nn, Tensor + +from text_recognizer.networks.vqvae.norm import Normalize + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator loss in Pix2Pix.""" + + def __init__( + self, in_channels: int = 1, num_channels: int = 32, num_layers: int = 3 + ) -> None: + super().__init__() + self.in_channels = in_channels + self.num_channels = num_channels + self.num_layers = num_layers + self.discriminator = self._build_discriminator() + + def _build_discriminator(self) -> nn.Sequential: + """Builds discriminator.""" + discriminator = [ + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.num_channels, + kernel_size=4, + stride=2, + padding=1, + ), + nn.Mish(inplace=True), + ] + in_channels = self.num_channels + for n in range(1, self.num_layers): + discriminator += [ + nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels * n, + kernel_size=4, + stride=2, + padding=1, + ), + Normalize(num_channels=in_channels * n), + nn.Mish(inplace=True), + ] + in_channels *= n + + discriminator += [ + nn.Conv2d( + in_channels=self.num_channels * (self.num_layers - 1), + out_channels=1, + kernel_size=4, + padding=1, + ) + ] + return nn.Sequential(*discriminator) + + def forward(self, x: Tensor) -> Tensor: + """Forward pass through discriminator.""" + return self.discriminator(x) diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py new file mode 100644 index 0000000..8bb568f --- /dev/null +++ b/text_recognizer/criterions/vqgan_loss.py @@ -0,0 +1,85 @@ +"""VQGAN loss for PyTorch Lightning.""" +from typing import Dict +from click.types import Tuple + +import torch +from torch import nn, Tensor +import torch.nn.functional as F + +from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator + + +class VQGANLoss(nn.Module): + """VQGAN loss.""" + + def __init__( + self, + reconstruction_loss: nn.L1Loss, + discriminator: NLayerDiscriminator, + vq_loss_weight: float = 1.0, + discriminator_weight: float = 1.0, + ) -> None: + super().__init__() + self.reconstruction_loss = reconstruction_loss + self.discriminator = discriminator + self.vq_loss_weight = vq_loss_weight + self.discriminator_weight = discriminator_weight + + @staticmethod + def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor: + """Calculates the adversarial loss.""" + loss_real = torch.mean(F.relu(1.0 - logits_real)) + loss_fake = torch.mean(F.relu(1.0 + logits_fake)) + d_loss = (loss_real + loss_fake) / 2.0 + return d_loss + + def forward( + self, + data: Tensor, + reconstructions: Tensor, + vq_loss: Tensor, + optimizer_idx: int, + stage: str, + ) -> Tuple[Tensor, Dict[str, Tensor]]: + """Calculates the VQGAN loss.""" + rec_loss = self.reconstruction_loss( + data.contiguous(), reconstructions.contiguous() + ) + + # GAN part. + if optimizer_idx == 0: + logits_fake = self.discriminator(reconstructions.contiguous()) + g_loss = -torch.mean(logits_fake) + + loss = ( + rec_loss + + self.discriminator_weight * g_loss + + self.vq_loss_weight * vq_loss + ) + log = { + f"{stage}/loss": loss, + f"{stage}/vq_loss": vq_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/g_loss": g_loss, + } + return loss, log + + if optimizer_idx == 1: + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + logits_real = self.discriminator(data.contiguous().detach()) + + d_loss = self.adversarial_loss( + logits_real=logits_real, logits_fake=logits_fake + ) + loss = ( + rec_loss + + self.discriminator_weight * d_loss + + self.vq_loss_weight * vq_loss + ) + log = { + f"{stage}/loss": loss, + f"{stage}/vq_loss": vq_loss, + f"{stage}/rec_loss": rec_loss, + f"{stage}/d_loss": d_loss, + } + return loss, log diff --git a/text_recognizer/data/emnist_mapping.py b/text_recognizer/data/emnist_mapping.py index 3e91594..4406db7 100644 --- a/text_recognizer/data/emnist_mapping.py +++ b/text_recognizer/data/emnist_mapping.py @@ -9,7 +9,9 @@ from text_recognizer.data.emnist import emnist_mapping class EmnistMapping(AbstractMapping): - def __init__(self, extra_symbols: Optional[Set[str]] = None, lower: bool = True) -> None: + def __init__( + self, extra_symbols: Optional[Set[str]] = None, lower: bool = True + ) -> None: self.extra_symbols = set(extra_symbols) if extra_symbols is not None else None self.mapping, self.inverse_mapping, self.input_size = emnist_mapping( self.extra_symbols @@ -20,10 +22,12 @@ class EmnistMapping(AbstractMapping): def _to_lower(self) -> None: """Converts mapping to lowercase letters only.""" + def _filter(x: int) -> int: if 40 <= x: return x - 26 return x + self.inverse_mapping = {v: _filter(k) for k, v in enumerate(self.mapping)} self.mapping = [c for c in self.mapping if not c.isupper()] diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index ab3fa35..8b68ed9 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -24,8 +24,8 @@ class BaseLitModel(LightningModule): network: Type[nn.Module] = attr.ib() mapping: Type[AbstractMapping] = attr.ib() loss_fn: Type[nn.Module] = attr.ib() - optimizer_config: DictConfig = attr.ib() - lr_scheduler_config: DictConfig = attr.ib() + optimizer_configs: DictConfig = attr.ib() + lr_scheduler_configs: DictConfig = attr.ib() train_acc: torchmetrics.Accuracy = attr.ib( init=False, default=torchmetrics.Accuracy() ) @@ -45,40 +45,55 @@ class BaseLitModel(LightningModule): ) -> None: optimizer.zero_grad(set_to_none=True) - def _configure_optimizer(self) -> Type[torch.optim.Optimizer]: + def _configure_optimizer(self) -> List[Type[torch.optim.Optimizer]]: """Configures the optimizer.""" - log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>") - return hydra.utils.instantiate( - self.optimizer_config, params=self.network.parameters() - ) - - def _configure_lr_scheduler( - self, optimizer: Type[torch.optim.Optimizer] - ) -> Dict[str, Any]: + optimizers = [] + for optimizer_config in self.optimizer_configs.values(): + network = getattr(self, optimizer_config.parameters) + del optimizer_config.parameters + log.info(f"Instantiating optimizer <{optimizer_config._target_}>") + optimizers.append( + hydra.utils.instantiate( + self.optimizer_config, params=network.parameters() + ) + ) + return optimizers + + def _configure_lr_schedulers( + self, optimizers: List[Type[torch.optim.Optimizer]] + ) -> List[Dict[str, Any]]: """Configures the lr scheduler.""" - # Extract non-class arguments. - monitor = self.lr_scheduler_config.monitor - interval = self.lr_scheduler_config.interval - del self.lr_scheduler_config.monitor - del self.lr_scheduler_config.interval - - log.info( - f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>" - ) - scheduler = { - "monitor": monitor, - "interval": interval, - "scheduler": hydra.utils.instantiate( - self.lr_scheduler_config, optimizer=optimizer - ), - } - return scheduler - - def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]: + schedulers = [] + for optimizer, lr_scheduler_config in zip( + optimizers, self.lr_scheduler_configs.values() + ): + # Extract non-class arguments. + monitor = lr_scheduler_config.monitor + interval = lr_scheduler_config.interval + del lr_scheduler_config.monitor + del lr_scheduler_config.interval + + log.info( + f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>" + ) + scheduler = { + "monitor": monitor, + "interval": interval, + "scheduler": hydra.utils.instantiate( + lr_scheduler_config, optimizer=optimizer + ), + } + schedulers.append(scheduler) + + return schedulers + + def configure_optimizers( + self, + ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]: """Configures optimizer and lr scheduler.""" - optimizer = self._configure_optimizer() - scheduler = self._configure_lr_scheduler(optimizer) - return [optimizer], [scheduler] + optimizers = self._configure_optimizer() + schedulers = self._configure_lr_scheduler(optimizers) + return optimizers, schedulers def forward(self, data: Tensor) -> Tensor: """Feedforward pass.""" diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index 5fb84a7..75f7523 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -60,6 +60,8 @@ class TransformerLitModel(BaseLitModel): pred = self(data) self.val_cer(pred, targets) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(pred, targets) + self.log("val/acc", self.test_acc, on_step=False, on_epoch=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -69,6 +71,8 @@ class TransformerLitModel(BaseLitModel): pred = self(data) self.test_cer(pred, targets) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(pred, targets) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) def predict(self, x: Tensor) -> Tensor: """Predicts text in image. diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py new file mode 100644 index 0000000..8ff65cc --- /dev/null +++ b/text_recognizer/models/vqgan.py @@ -0,0 +1,135 @@ +"""PyTorch Lightning model for base Transformers.""" +from typing import Tuple + +import attr +from torch import Tensor + +from text_recognizer.models.base import BaseLitModel +from text_recognizer.criterions.vqgan_loss import VQGANLoss + + +@attr.s(auto_attribs=True, eq=False) +class VQVAELitModel(BaseLitModel): + """A PyTorch Lightning model for transformer networks.""" + + loss_fn: VQGANLoss = attr.ib() + latent_loss_weight: float = attr.ib(default=0.25) + + def forward(self, data: Tensor) -> Tensor: + """Forward pass with the transformer network.""" + return self.network(data) + + def training_step( + self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int + ) -> Tensor: + """Training step.""" + data, _ = batch + + reconstructions, vq_loss = self(data) + loss = self.loss_fn(reconstructions, data) + + if optimizer_idx == 0: + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=optimizer_idx, + stage="train", + ) + self.log( + "train/loss", + loss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return loss + + if optimizer_idx == 1: + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=optimizer_idx, + stage="train", + ) + self.log( + "train/discriminator_loss", + loss, + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, _ = batch + reconstructions, vq_loss = self(data) + + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=0, + stage="val", + ) + self.log( + "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + self.log( + "val/rec_loss", + log["val/rec_loss"], + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log) + + _, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=1, + stage="val", + ) + self.log_dict(log) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, _ = batch + reconstructions, vq_loss = self(data) + + loss, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=0, + stage="test", + ) + self.log( + "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + self.log( + "test/rec_loss", + log["test/rec_loss"], + prog_bar=True, + logger=True, + on_step=True, + on_epoch=True, + ) + self.log_dict(log) + + _, log = self.loss_fn( + data=data, + reconstructions=reconstructions, + vq_loss=vq_loss, + optimizer_idx=1, + stage="test", + ) + self.log_dict(log) diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index ef9a59a..56229b3 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -28,8 +28,8 @@ class VQVAELitModel(BaseLitModel): self.log("train/vq_loss", vq_loss) self.log("train/loss", loss) - self.train_acc(reconstructions, data) - self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) + # self.train_acc(reconstructions, data) + # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: @@ -42,8 +42,8 @@ class VQVAELitModel(BaseLitModel): self.log("val/vq_loss", vq_loss) self.log("val/loss", loss, prog_bar=True) - self.val_acc(reconstructions, data) - self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) + # self.val_acc(reconstructions, data) + # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -53,5 +53,5 @@ class VQVAELitModel(BaseLitModel): loss = loss + self.latent_loss_weight * vq_loss self.log("test/vq_loss", vq_loss) self.log("test/loss", loss) - self.test_acc(reconstructions, data) - self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + # self.test_acc(reconstructions, data) + # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) diff --git a/text_recognizer/networks/vqvae/decoder.py b/text_recognizer/networks/vqvae/decoder.py index f51e0a3..fcbed57 100644 --- a/text_recognizer/networks/vqvae/decoder.py +++ b/text_recognizer/networks/vqvae/decoder.py @@ -12,7 +12,14 @@ from text_recognizer.networks.vqvae.residual import Residual class Decoder(nn.Module): """A CNN encoder network.""" - def __init__(self, out_channels: int, hidden_dim: int, channels_multipliers: Sequence[int], dropout_rate: float, activation: str = "mish") -> None: + def __init__( + self, + out_channels: int, + hidden_dim: int, + channels_multipliers: Sequence[int], + dropout_rate: float, + activation: str = "mish", + ) -> None: super().__init__() self.out_channels = out_channels self.hidden_dim = hidden_dim @@ -33,9 +40,9 @@ class Decoder(nn.Module): use_norm=True, ), ] - + activation_fn = activation_function(self.activation) - out_channels_multipliers = self.channels_multipliers + (1, ) + out_channels_multipliers = self.channels_multipliers + (1,) num_blocks = len(self.channels_multipliers) for i in range(num_blocks): diff --git a/text_recognizer/networks/vqvae/encoder.py b/text_recognizer/networks/vqvae/encoder.py index ad8f950..4a5c976 100644 --- a/text_recognizer/networks/vqvae/encoder.py +++ b/text_recognizer/networks/vqvae/encoder.py @@ -11,7 +11,14 @@ from text_recognizer.networks.vqvae.residual import Residual class Encoder(nn.Module): """A CNN encoder network.""" - def __init__(self, in_channels: int, hidden_dim: int, channels_multipliers: List[int], dropout_rate: float, activation: str = "mish") -> None: + def __init__( + self, + in_channels: int, + hidden_dim: int, + channels_multipliers: List[int], + dropout_rate: float, + activation: str = "mish", + ) -> None: super().__init__() self.in_channels = in_channels self.hidden_dim = hidden_dim @@ -33,7 +40,7 @@ class Encoder(nn.Module): ] num_blocks = len(self.channels_multipliers) - channels_multipliers = (1, ) + self.channels_multipliers + channels_multipliers = (1,) + self.channels_multipliers activation_fn = activation_function(self.activation) for i in range(num_blocks): diff --git a/text_recognizer/networks/vqvae/norm.py b/text_recognizer/networks/vqvae/norm.py index 3e6963a..d73f9f8 100644 --- a/text_recognizer/networks/vqvae/norm.py +++ b/text_recognizer/networks/vqvae/norm.py @@ -6,13 +6,17 @@ from torch import nn, Tensor @attr.s(eq=False) class Normalize(nn.Module): num_channels: int = attr.ib() + num_groups: int = attr.ib(default=32) norm: nn.GroupNorm = attr.ib(init=False) def __attrs_post_init__(self) -> None: """Post init configuration.""" super().__init__() self.norm = nn.GroupNorm( - num_groups=self.num_channels, num_channels=self.num_channels, eps=1.0e-6, affine=True + num_groups=self.num_groups, + num_channels=self.num_channels, + eps=1.0e-6, + affine=True, ) def forward(self, x: Tensor) -> Tensor: diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py index 5c580df..b9e6080 100644 --- a/text_recognizer/networks/vqvae/pixelcnn.py +++ b/text_recognizer/networks/vqvae/pixelcnn.py @@ -44,7 +44,7 @@ class Encoder(nn.Module): ), ] num_blocks = len(self.channels_multipliers) - in_channels_multipliers = (1,) + self.channels_multipliers + in_channels_multipliers = (1,) + self.channels_multipliers for i in range(num_blocks): in_channels = self.hidden_dim * in_channels_multipliers[i] out_channels = self.hidden_dim * self.channels_multipliers[i] @@ -68,7 +68,7 @@ class Encoder(nn.Module): dropout_rate=self.dropout_rate, use_norm=True, ), - Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]) + Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]), ] encoder += [ @@ -125,7 +125,7 @@ class Decoder(nn.Module): ), ] - out_channels_multipliers = self.channels_multipliers + (1, ) + out_channels_multipliers = self.channels_multipliers + (1,) num_blocks = len(self.channels_multipliers) for i in range(num_blocks): @@ -140,11 +140,7 @@ class Decoder(nn.Module): ) ) if i == 0: - decoder.append( - Attention( - in_channels=out_channels - ) - ) + decoder.append(Attention(in_channels=out_channels)) decoder.append(Upsample()) decoder += [ diff --git a/text_recognizer/networks/vqvae/residual.py b/text_recognizer/networks/vqvae/residual.py index 4ed3781..46b091d 100644 --- a/text_recognizer/networks/vqvae/residual.py +++ b/text_recognizer/networks/vqvae/residual.py @@ -18,7 +18,13 @@ class Residual(nn.Module): super().__init__() self.block = self._build_res_block() if self.in_channels != self.out_channels: - self.conv_shortcut = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=1) + self.conv_shortcut = nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + padding=1, + ) else: self.conv_shortcut = None diff --git a/text_recognizer/networks/vqvae/vqvae.py b/text_recognizer/networks/vqvae/vqvae.py index 0646119..e8660c4 100644 --- a/text_recognizer/networks/vqvae/vqvae.py +++ b/text_recognizer/networks/vqvae/vqvae.py @@ -32,7 +32,6 @@ class VQVAE(nn.Module): num_embeddings=num_embeddings, embedding_dim=embedding_dim, decay=decay, ) - def encode(self, x: Tensor) -> Tensor: """Encodes input to a latent code.""" z_e = self.encoder(x) diff --git a/training/conf/config.yaml b/training/conf/config.yaml index c606366..5897d87 100644 --- a/training/conf/config.yaml +++ b/training/conf/config.yaml @@ -6,11 +6,13 @@ defaults: - datamodule: iam_extended_paragraphs - hydra: default - logger: wandb - - lr_scheduler: one_cycle + - lr_schedulers: + - one_cycle - mapping: word_piece - model: lit_transformer - network: conv_transformer - - optimizer: madgrad + - optimizers: + - madgrad - trainer: default seed: 4711 @@ -32,7 +34,9 @@ work_dir: ${hydra:runtime.cwd} debug: False # pretty print config at the start of the run using Rich library -print_config: True +print_config: false # disable python warnings if they annoy you -ignore_warnings: True +ignore_warnings: true + +summary: null # [1, 576, 640] diff --git a/training/conf/criterion/mae.yaml b/training/conf/criterion/mae.yaml new file mode 100644 index 0000000..cb07467 --- /dev/null +++ b/training/conf/criterion/mae.yaml @@ -0,0 +1,2 @@ +_target_: torch.nn.L1Loss +reduction: mean diff --git a/training/conf/criterion/vqgan_loss.yaml b/training/conf/criterion/vqgan_loss.yaml new file mode 100644 index 0000000..a1c886e --- /dev/null +++ b/training/conf/criterion/vqgan_loss.yaml @@ -0,0 +1,12 @@ +_target_: text_recognizer.criterions.vqgan_loss.VQGANLoss +reconstruction_loss: + _target_: torch.nn.L1Loss + reduction: mean +discriminator: + _target_: text_recognizer.criterions.n_layer_discriminator.NLayerDiscriminator + in_channels: 1 + num_channels: 32 + num_layers: 3 +vq_loss_weight: 1.0 +discriminator_weight: 1.0 + diff --git a/training/conf/experiment/vqgan.yaml b/training/conf/experiment/vqgan.yaml new file mode 100644 index 0000000..3d97892 --- /dev/null +++ b/training/conf/experiment/vqgan.yaml @@ -0,0 +1,55 @@ +# @package _global_ + +defaults: + - override /network: vqvae + - override /criterion: vqgan_loss + - override /model: lit_vqgan + - override /callbacks: wandb_vae + - override /lr_schedulers: null + +datamodule: + batch_size: 8 + +lr_schedulers: + - generator: + T_max: 256 + eta_min: 0.0 + last_epoch: -1 + + interval: epoch + monitor: val/loss + + - discriminator: + T_max: 256 + eta_min: 0.0 + last_epoch: -1 + + interval: epoch + monitor: val/loss + +optimizer: + - generator: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 256 + eta_min: 0.0 + last_epoch: -1 + + interval: epoch + monitor: val/loss + parameters: network + + - discriminator: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 256 + eta_min: 0.0 + last_epoch: -1 + + interval: epoch + monitor: val/loss + parameters: loss_fn + +trainer: + max_epochs: 256 + # gradient_clip_val: 0.25 + +summary: null diff --git a/training/conf/experiment/vqvae.yaml b/training/conf/experiment/vqvae.yaml index 7a9e643..397a039 100644 --- a/training/conf/experiment/vqvae.yaml +++ b/training/conf/experiment/vqvae.yaml @@ -2,17 +2,18 @@ defaults: - override /network: vqvae - - override /criterion: mse + - override /criterion: mae - override /model: lit_vqvae - override /callbacks: wandb_vae - - override /lr_scheduler: cosine_annealing + - override /lr_schedulers: + - cosine_annealing trainer: - max_epochs: 64 + max_epochs: 256 # gradient_clip_val: 0.25 datamodule: - batch_size: 16 + batch_size: 8 # lr_scheduler: # epochs: 64 @@ -21,4 +22,4 @@ datamodule: # optimizer: # lr: 1.0e-3 -summary: [1, 576, 640] +summary: null diff --git a/training/conf/experiment/vqvae_pixelcnn.yaml b/training/conf/experiment/vqvae_pixelcnn.yaml new file mode 100644 index 0000000..4fae782 --- /dev/null +++ b/training/conf/experiment/vqvae_pixelcnn.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +defaults: + - override /network: vqvae_pixelcnn + - override /criterion: mae + - override /model: lit_vqvae + - override /callbacks: wandb_vae + - override /lr_schedulers: + - cosine_annealing + +trainer: + max_epochs: 256 + # gradient_clip_val: 0.25 + +datamodule: + batch_size: 8 + +# lr_scheduler: + # epochs: 64 + # steps_per_epoch: 1245 + +# optimizer: + # lr: 1.0e-3 + diff --git a/training/conf/lr_scheduler/cosine_annealing.yaml b/training/conf/lr_scheduler/cosine_annealing.yaml index 62667bb..c53ee3a 100644 --- a/training/conf/lr_scheduler/cosine_annealing.yaml +++ b/training/conf/lr_scheduler/cosine_annealing.yaml @@ -1,7 +1,8 @@ -_target_: torch.optim.lr_scheduler.CosineAnnealingLR -T_max: 64 -eta_min: 0.0 -last_epoch: -1 +cosine_annealing: + _target_: torch.optim.lr_scheduler.CosineAnnealingLR + T_max: 256 + eta_min: 0.0 + last_epoch: -1 -interval: epoch -monitor: val/loss + interval: epoch + monitor: val/loss diff --git a/training/conf/lr_scheduler/one_cycle.yaml b/training/conf/lr_scheduler/one_cycle.yaml index fb5987a..c60577a 100644 --- a/training/conf/lr_scheduler/one_cycle.yaml +++ b/training/conf/lr_scheduler/one_cycle.yaml @@ -1,19 +1,20 @@ -_target_: torch.optim.lr_scheduler.OneCycleLR -max_lr: 1.0e-3 -total_steps: null -epochs: 512 -steps_per_epoch: 4992 -pct_start: 0.3 -anneal_strategy: cos -cycle_momentum: true -base_momentum: 0.85 -max_momentum: 0.95 -div_factor: 25.0 -final_div_factor: 10000.0 -three_phase: true -last_epoch: -1 -verbose: false +onc_cycle: + _target_: torch.optim.lr_scheduler.OneCycleLR + max_lr: 1.0e-3 + total_steps: null + epochs: 512 + steps_per_epoch: 4992 + pct_start: 0.3 + anneal_strategy: cos + cycle_momentum: true + base_momentum: 0.85 + max_momentum: 0.95 + div_factor: 25.0 + final_div_factor: 10000.0 + three_phase: true + last_epoch: -1 + verbose: false -# Non-class arguments -interval: step -monitor: val/loss + # Non-class arguments + interval: step + monitor: val/loss diff --git a/training/conf/network/encoder/pixelcnn_decoder.yaml b/training/conf/network/decoder/pixelcnn_decoder.yaml index 3895164..cdddb7a 100644 --- a/training/conf/network/encoder/pixelcnn_decoder.yaml +++ b/training/conf/network/decoder/pixelcnn_decoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.pixelcnn.Decoder out_channels: 1 hidden_dim: 8 -channels_multipliers: [8, 8, 2, 1] +channels_multipliers: [8, 2, 1] dropout_rate: 0.25 diff --git a/training/conf/network/decoder/vae_decoder.yaml b/training/conf/network/decoder/vae_decoder.yaml index 0a36a54..a5e7286 100644 --- a/training/conf/network/decoder/vae_decoder.yaml +++ b/training/conf/network/decoder/vae_decoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.decoder.Decoder out_channels: 1 hidden_dim: 32 -channels_multipliers: [4, 4, 2, 1] +channels_multipliers: [8, 8, 4, 1] dropout_rate: 0.25 diff --git a/training/conf/network/decoder/pixelcnn_encoder.yaml b/training/conf/network/encoder/pixelcnn_encoder.yaml index 47a130d..f12957b 100644 --- a/training/conf/network/decoder/pixelcnn_encoder.yaml +++ b/training/conf/network/encoder/pixelcnn_encoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.pixelcnn.Encoder in_channels: 1 hidden_dim: 8 -channels_multipliers: [1, 2, 8, 8] +channels_multipliers: [1, 2, 8] dropout_rate: 0.25 diff --git a/training/conf/network/encoder/vae_encoder.yaml b/training/conf/network/encoder/vae_encoder.yaml index dacd389..58e905d 100644 --- a/training/conf/network/encoder/vae_encoder.yaml +++ b/training/conf/network/encoder/vae_encoder.yaml @@ -1,5 +1,5 @@ _target_: text_recognizer.networks.vqvae.encoder.Encoder in_channels: 1 hidden_dim: 32 -channels_multipliers: [1, 2, 4, 4] +channels_multipliers: [1, 2, 4, 8, 8] dropout_rate: 0.25 diff --git a/training/conf/network/vqvae.yaml b/training/conf/network/vqvae.yaml index d97e9b6..835d0b7 100644 --- a/training/conf/network/vqvae.yaml +++ b/training/conf/network/vqvae.yaml @@ -3,7 +3,7 @@ defaults: - decoder: vae_decoder _target_: text_recognizer.networks.vqvae.vqvae.VQVAE -hidden_dim: 128 +hidden_dim: 256 embedding_dim: 32 num_embeddings: 1024 decay: 0.99 diff --git a/training/conf/network/vqvae_pixelcnn.yaml b/training/conf/network/vqvae_pixelcnn.yaml index 10200bc..cd850af 100644 --- a/training/conf/network/vqvae_pixelcnn.yaml +++ b/training/conf/network/vqvae_pixelcnn.yaml @@ -5,5 +5,5 @@ defaults: _target_: text_recognizer.networks.vqvae.vqvae.VQVAE hidden_dim: 64 embedding_dim: 32 -num_embeddings: 512 +num_embeddings: 1024 decay: 0.99 diff --git a/training/conf/optimizer/madgrad.yaml b/training/conf/optimizer/madgrad.yaml index 458b116..a6c059d 100644 --- a/training/conf/optimizer/madgrad.yaml +++ b/training/conf/optimizer/madgrad.yaml @@ -1,5 +1,8 @@ -_target_: madgrad.MADGRAD -lr: 3.0e-4 -momentum: 0.9 -weight_decay: 0 -eps: 1.0e-6 +madgrad: + _target_: madgrad.MADGRAD + lr: 1.0e-3 + momentum: 0.9 + weight_decay: 0 + eps: 1.0e-6 + + parameters: network diff --git a/training/run.py b/training/run.py index a2529b0..0cf52e3 100644 --- a/training/run.py +++ b/training/run.py @@ -50,8 +50,8 @@ def run(config: DictConfig) -> Optional[float]: mapping=mapping, network=network, loss_fn=loss_fn, - optimizer_config=config.optimizer, - lr_scheduler_config=config.lr_scheduler, + optimizer_configs=config.optimizers, + lr_scheduler_configs=config.lr_schedulers, _recursive_=False, ) |