From 240f5e9f20032e82515fa66ce784619527d1041e Mon Sep 17 00:00:00 2001
From: Gustaf Rydholm <gustaf.rydholm@gmail.com>
Date: Sun, 8 Aug 2021 19:59:55 +0200
Subject: Add VQGAN and loss function

---
 text_recognizer/models/base.py        |  81 +++++++++++---------
 text_recognizer/models/transformer.py |   4 +
 text_recognizer/models/vqgan.py       | 134 ++++++++++++++++++++++++++++++++++
 text_recognizer/models/vqvae.py       |  12 +--
 4 files changed, 192 insertions(+), 39 deletions(-)
 create mode 100644 text_recognizer/models/vqgan.py

(limited to 'text_recognizer/models')

diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index ab3fa35..8b68ed9 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -24,8 +24,8 @@ class BaseLitModel(LightningModule):
     network: Type[nn.Module] = attr.ib()
     mapping: Type[AbstractMapping] = attr.ib()
     loss_fn: Type[nn.Module] = attr.ib()
-    optimizer_config: DictConfig = attr.ib()
-    lr_scheduler_config: DictConfig = attr.ib()
+    optimizer_configs: DictConfig = attr.ib()
+    lr_scheduler_configs: DictConfig = attr.ib()
     train_acc: torchmetrics.Accuracy = attr.ib(
         init=False, default=torchmetrics.Accuracy()
     )
@@ -45,40 +45,55 @@ class BaseLitModel(LightningModule):
     ) -> None:
         optimizer.zero_grad(set_to_none=True)
 
-    def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
+    def _configure_optimizer(self) -> List[Type[torch.optim.Optimizer]]:
         """Configures the optimizer."""
-        log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
-        return hydra.utils.instantiate(
-            self.optimizer_config, params=self.network.parameters()
-        )
-
-    def _configure_lr_scheduler(
-        self, optimizer: Type[torch.optim.Optimizer]
-    ) -> Dict[str, Any]:
+        optimizers = []
+        for optimizer_config in self.optimizer_configs.values():
+            module = getattr(self, optimizer_config.parameters)
+            del optimizer_config.parameters
+            log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
+            optimizers.append(
+                hydra.utils.instantiate(
+                    optimizer_config, params=module.parameters()
+                )
+            )
+        return optimizers
+
+    def _configure_lr_schedulers(
+        self, optimizers: List[Type[torch.optim.Optimizer]]
+    ) -> List[Dict[str, Any]]:
         """Configures the lr scheduler."""
-        # Extract non-class arguments.
-        monitor = self.lr_scheduler_config.monitor
-        interval = self.lr_scheduler_config.interval
-        del self.lr_scheduler_config.monitor
-        del self.lr_scheduler_config.interval
-
-        log.info(
-            f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
-        )
-        scheduler = {
-            "monitor": monitor,
-            "interval": interval,
-            "scheduler": hydra.utils.instantiate(
-                self.lr_scheduler_config, optimizer=optimizer
-            ),
-        }
-        return scheduler
-
-    def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+        schedulers = []
+        for optimizer, lr_scheduler_config in zip(
+            optimizers, self.lr_scheduler_configs.values()
+        ):
+            # Extract non-class arguments.
+            monitor = lr_scheduler_config.monitor
+            interval = lr_scheduler_config.interval
+            del lr_scheduler_config.monitor
+            del lr_scheduler_config.interval
+
+            log.info(
+                f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
+            )
+            scheduler = {
+                "monitor": monitor,
+                "interval": interval,
+                "scheduler": hydra.utils.instantiate(
+                    lr_scheduler_config, optimizer=optimizer
+                ),
+            }
+            schedulers.append(scheduler)
+
+        return schedulers
+
+    def configure_optimizers(
+        self,
+    ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
         """Configures optimizer and lr scheduler."""
-        optimizer = self._configure_optimizer()
-        scheduler = self._configure_lr_scheduler(optimizer)
-        return [optimizer], [scheduler]
+        optimizers = self._configure_optimizer()
+        schedulers = self._configure_lr_schedulers(optimizers)
+        return optimizers, schedulers
 
     def forward(self, data: Tensor) -> Tensor:
         """Feedforward pass."""
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 5fb84a7..75f7523 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -60,6 +60,8 @@ class TransformerLitModel(BaseLitModel):
         pred = self(data)
         self.val_cer(pred, targets)
         self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.val_acc(pred, targets)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -69,6 +71,8 @@ class TransformerLitModel(BaseLitModel):
         pred = self(data)
         self.test_cer(pred, targets)
         self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.test_acc(pred, targets)
+        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
 
     def predict(self, x: Tensor) -> Tensor:
         """Predicts text in image.
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
new file mode 100644
index 0000000..8ff65cc
--- /dev/null
+++ b/text_recognizer/models/vqgan.py
@@ -0,0 +1,134 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Tuple
+
+import attr
+from torch import Tensor
+
+from text_recognizer.models.base import BaseLitModel
+from text_recognizer.criterions.vqgan_loss import VQGANLoss
+
+
+@attr.s(auto_attribs=True, eq=False)
+class VQGANLitModel(BaseLitModel):
+    """A PyTorch Lightning model for the VQGAN network."""
+
+    loss_fn: VQGANLoss = attr.ib()
+    latent_loss_weight: float = attr.ib(default=0.25)
+
+    def forward(self, data: Tensor) -> Tensor:
+        """Forward pass with the transformer network."""
+        return self.network(data)
+
+    def training_step(
+        self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
+    ) -> Tensor:
+        """Training step."""
+        data, _ = batch
+
+        reconstructions, vq_loss = self(data)
+
+        if optimizer_idx == 0:
+            loss, log = self.loss_fn(
+                data=data,
+                reconstructions=reconstructions,
+                vq_loss=vq_loss,
+                optimizer_idx=optimizer_idx,
+                stage="train",
+            )
+            self.log(
+                "train/loss",
+                loss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return loss
+
+        if optimizer_idx == 1:
+            loss, log = self.loss_fn(
+                data=data,
+                reconstructions=reconstructions,
+                vq_loss=vq_loss,
+                optimizer_idx=optimizer_idx,
+                stage="train",
+            )
+            self.log(
+                "train/discriminator_loss",
+                loss,
+                prog_bar=True,
+                logger=True,
+                on_step=True,
+                on_epoch=True,
+            )
+            self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+            return loss
+
+    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Validation step."""
+        data, _ = batch
+        reconstructions, vq_loss = self(data)
+
+        loss, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=0,
+            stage="val",
+        )
+        self.log(
+            "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+        )
+        self.log(
+            "val/rec_loss",
+            log["val/rec_loss"],
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=True,
+        )
+        self.log_dict(log)
+
+        _, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=1,
+            stage="val",
+        )
+        self.log_dict(log)
+
+    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+        """Test step."""
+        data, _ = batch
+        reconstructions, vq_loss = self(data)
+
+        loss, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=0,
+            stage="test",
+        )
+        self.log(
+            "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+        )
+        self.log(
+            "test/rec_loss",
+            log["test/rec_loss"],
+            prog_bar=True,
+            logger=True,
+            on_step=True,
+            on_epoch=True,
+        )
+        self.log_dict(log)
+
+        _, log = self.loss_fn(
+            data=data,
+            reconstructions=reconstructions,
+            vq_loss=vq_loss,
+            optimizer_idx=1,
+            stage="test",
+        )
+        self.log_dict(log)
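
The training, validation, and test steps above assume a VQGANLoss that is
called with keyword arguments and returns a (loss, log_dict) pair, branching
on optimizer_idx (0 for the generator/autoencoder update, 1 for the
discriminator update). A rough stand-in for that assumed interface; it is not
the actual implementation in text_recognizer/criterions/vqgan_loss.py:

    from typing import Dict, Tuple

    import torch
    import torch.nn.functional as F
    from torch import Tensor, nn


    class VQGANLossStub(nn.Module):
        """Illustrative stand-in for the assumed VQGANLoss interface."""

        def forward(
            self,
            data: Tensor,
            reconstructions: Tensor,
            vq_loss: Tensor,
            optimizer_idx: int,
            stage: str,
        ) -> Tuple[Tensor, Dict[str, Tensor]]:
            if optimizer_idx == 0:
                # Generator/autoencoder update: reconstruction plus codebook
                # commitment loss (a real VQGAN loss adds perceptual and
                # adversarial terms).
                rec_loss = F.l1_loss(reconstructions, data)
                loss = rec_loss + vq_loss
                return loss, {f"{stage}/rec_loss": rec_loss}
            # Discriminator update: a placeholder in this sketch.
            d_loss = torch.zeros((), device=data.device)
            return d_loss, {f"{stage}/discriminator_loss": d_loss}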
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index ef9a59a..56229b3 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -28,8 +28,8 @@ class VQVAELitModel(BaseLitModel):
         self.log("train/vq_loss", vq_loss)
         self.log("train/loss", loss)
 
-        self.train_acc(reconstructions, data)
-        self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
+        # self.train_acc(reconstructions, data)
+        # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
         return loss
 
     def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -42,8 +42,8 @@ class VQVAELitModel(BaseLitModel):
         self.log("val/vq_loss", vq_loss)
         self.log("val/loss", loss, prog_bar=True)
 
-        self.val_acc(reconstructions, data)
-        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+        # self.val_acc(reconstructions, data)
+        # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
 
     def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Test step."""
@@ -53,5 +53,5 @@ class VQVAELitModel(BaseLitModel):
         loss = loss + self.latent_loss_weight * vq_loss
         self.log("test/vq_loss", vq_loss)
         self.log("test/loss", loss)
-        self.test_acc(reconstructions, data)
-        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+        # self.test_acc(reconstructions, data)
+        # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
-- 
cgit v1.2.3-70-g09d2