summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqgan.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/vqgan.py')
-rw-r--r--text_recognizer/models/vqgan.py135
1 files changed, 135 insertions, 0 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
new file mode 100644
index 0000000..8ff65cc
--- /dev/null
+++ b/text_recognizer/models/vqgan.py
@@ -0,0 +1,135 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Tuple
+
+import attr
+from torch import Tensor
+
+from text_recognizer.models.base import BaseLitModel
+from text_recognizer.criterions.vqgan_loss import VQGANLoss
+
+
+@attr.s(auto_attribs=True, eq=False)
+class VQVAELitModel(BaseLitModel):
+ """A PyTorch Lightning model for transformer networks."""
+
+ loss_fn: VQGANLoss = attr.ib()
+ latent_loss_weight: float = attr.ib(default=0.25)
+
+ def forward(self, data: Tensor) -> Tensor:
+ """Forward pass with the transformer network."""
+ return self.network(data)
+
+ def training_step(
+ self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
+ ) -> Tensor:
+ """Training step."""
+ data, _ = batch
+
+ reconstructions, vq_loss = self(data)
+ loss = self.loss_fn(reconstructions, data)
+
+ if optimizer_idx == 0:
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=optimizer_idx,
+ stage="train",
+ )
+ self.log(
+ "train/loss",
+ loss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ if optimizer_idx == 1:
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=optimizer_idx,
+ stage="train",
+ )
+ self.log(
+ "train/discriminator_loss",
+ loss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Validation step."""
+ data, _ = batch
+ reconstructions, vq_loss = self(data)
+
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=0,
+ stage="val",
+ )
+ self.log(
+ "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+ )
+ self.log(
+ "val/rec_loss",
+ log["val/rec_loss"],
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log)
+
+ _, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=1,
+ stage="val",
+ )
+ self.log_dict(log)
+
+ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Test step."""
+ data, _ = batch
+ reconstructions, vq_loss = self(data)
+
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=0,
+ stage="test",
+ )
+ self.log(
+ "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+ )
+ self.log(
+ "test/rec_loss",
+ log["test/rec_loss"],
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log)
+
+ _, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=1,
+ stage="test",
+ )
+ self.log_dict(log)