diff options
Diffstat (limited to 'text_recognizer/models/vqgan.py')
-rw-r--r-- | text_recognizer/models/vqgan.py | 116 |
1 files changed, 0 insertions, 116 deletions
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py deleted file mode 100644 index 6a90e06..0000000 --- a/text_recognizer/models/vqgan.py +++ /dev/null @@ -1,116 +0,0 @@ -"""PyTorch Lightning model for base Transformers.""" -from typing import Tuple - -import attr -from torch import Tensor - -from text_recognizer.criterion.vqgan_loss import VQGANLoss -from text_recognizer.models.base import BaseLitModel - - -@attr.s(auto_attribs=True, eq=False) -class VQGANLitModel(BaseLitModel): - """A PyTorch Lightning model for transformer networks.""" - - loss_fn: VQGANLoss = attr.ib() - latent_loss_weight: float = attr.ib(default=0.25) - - def forward(self, data: Tensor) -> Tensor: - """Forward pass with the transformer network.""" - return self.network(data) - - def training_step( - self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int - ) -> Tensor: - """Training step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - - if optimizer_idx == 0: - loss, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=optimizer_idx, - global_step=self.global_step, - stage="train", - ) - self.log( - "train/loss", loss, prog_bar=True, - ) - self.log_dict(log, logger=True, on_step=True, on_epoch=True) - return loss - - if optimizer_idx == 1: - loss, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=optimizer_idx, - global_step=self.global_step, - stage="train", - ) - self.log( - "train/discriminator_loss", loss, prog_bar=True, - ) - self.log_dict(log, logger=True, on_step=True, on_epoch=True) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - - loss, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=0, - global_step=self.global_step, - stage="val", - ) - self.log( - "val/loss", loss, prog_bar=True, - ) - self.log_dict(log) - - _, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=1, - global_step=self.global_step, - stage="val", - ) - self.log_dict(log) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - - _, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=0, - global_step=self.global_step, - stage="test", - ) - self.log_dict(log) - - _, log = self.loss_fn( - data=data, - reconstructions=reconstructions, - commitment_loss=commitment_loss, - decoder_last_layer=self.network.decoder.decoder[-1].weight, - optimizer_idx=1, - global_step=self.global_step, - stage="test", - ) - self.log_dict(log) |