summaryrefslogtreecommitdiff
path: root/text_recognizer/models/vqvae.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/vqvae.py')
-rw-r--r--text_recognizer/models/vqvae.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
new file mode 100644
index 0000000..ef2213c
--- /dev/null
+++ b/text_recognizer/models/vqvae.py
@@ -0,0 +1,70 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Any, Dict, Union, Tuple, Type
+
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+import wandb
+
+from text_recognizer.models.base import LitBaseModel
+
+
+class LitVQVAEModel(LitBaseModel):
+ """A PyTorch Lightning model for transformer networks."""
+
+ def __init__(
+ self,
+ network: Type[nn.Module],
+ optimizer: Union[DictConfig, Dict],
+ lr_scheduler: Union[DictConfig, Dict],
+ criterion: Union[DictConfig, Dict],
+ monitor: str = "val_loss",
+ *args: Any,
+ **kwargs: Dict,
+ ) -> None:
+ super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+
+ def forward(self, data: Tensor) -> Tensor:
+ """Forward pass with the transformer network."""
+ return self.network.predict(data)
+
+ def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+ """Logs prediction on image with wandb."""
+ try:
+ self.logger.experiment.log(
+ {
+ "val_pred_examples": [
+ wandb.Image(data[0]),
+ wandb.Image(reconstructions[0]),
+ ]
+ }
+ )
+ except AttributeError:
+ pass
+
+ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+ """Training step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self.log("train_loss", loss)
+ return loss
+
+ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Validation step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self.log("val_loss", loss, prog_bar=True)
+ self._log_prediction(data, reconstructions)
+
+ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Test step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self._log_prediction(data, reconstructions)