summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqvae.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r--text_recognizer/models/vqvae.py45
1 files changed, 0 insertions, 45 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
deleted file mode 100644
index 4898852..0000000
--- a/text_recognizer/models/vqvae.py
+++ /dev/null
@@ -1,45 +0,0 @@
-"""PyTorch Lightning model for base Transformers."""
-from typing import Tuple
-
-import attr
-from torch import Tensor
-
-from text_recognizer.models.base import BaseLitModel
-
-
-@attr.s(auto_attribs=True, eq=False)
-class VQVAELitModel(BaseLitModel):
- """A PyTorch Lightning model for transformer networks."""
-
- commitment: float = attr.ib(default=0.25)
-
- def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
- return self.network(data)
-
- def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
- """Training step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("train/commitment_loss", commitment_loss)
- self.log("train/loss", loss)
- return loss
-
- def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Validation step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- self.log("val/commitment_loss", commitment_loss)
- self.log("val/loss", loss, prog_bar=True)
-
- def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
- """Test step."""
- data, _ = batch
- reconstructions, commitment_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
- loss = loss + self.commitment * commitment_loss
- self.log("test/commitment_loss", commitment_loss)
- self.log("test/loss", loss)