diff options
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r-- | text_recognizer/models/vqvae.py | 45 |
1 files changed, 0 insertions, 45 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py deleted file mode 100644 index 4898852..0000000 --- a/text_recognizer/models/vqvae.py +++ /dev/null @@ -1,45 +0,0 @@ -"""PyTorch Lightning model for base Transformers.""" -from typing import Tuple - -import attr -from torch import Tensor - -from text_recognizer.models.base import BaseLitModel - - -@attr.s(auto_attribs=True, eq=False) -class VQVAELitModel(BaseLitModel): - """A PyTorch Lightning model for transformer networks.""" - - commitment: float = attr.ib(default=0.25) - - def forward(self, data: Tensor) -> Tensor: - """Forward pass with the transformer network.""" - return self.network(data) - - def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: - """Training step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - loss = loss + self.commitment * commitment_loss - self.log("train/commitment_loss", commitment_loss) - self.log("train/loss", loss) - return loss - - def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Validation step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - self.log("val/commitment_loss", commitment_loss) - self.log("val/loss", loss, prog_bar=True) - - def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: - """Test step.""" - data, _ = batch - reconstructions, commitment_loss = self(data) - loss = self.loss_fn(reconstructions, data) - loss = loss + self.commitment * commitment_loss - self.log("test/commitment_loss", commitment_loss) - self.log("test/loss", loss) |