summaryrefslogtreecommitdiff
path: root/text_recognizer/models
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-04-22 08:15:58 +0200
commit1ca8b0b9e0613c1e02f6a5d8b49e20c4d6916412 (patch)
tree5e610ac459c9b254f8826e92372346f01f8e2412 /text_recognizer/models
parentffa4be4bf4e3758e01d52a9c1f354a05a90b93de (diff)
Fixed training script, able to train vqvae
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--text_recognizer/models/__init__.py3
-rw-r--r--text_recognizer/models/base.py9
-rw-r--r--text_recognizer/models/vqvae.py70
3 files changed, 82 insertions, 0 deletions
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
index e69de29..5ac2510 100644
--- a/text_recognizer/models/__init__.py
+++ b/text_recognizer/models/__init__.py
@@ -0,0 +1,3 @@
+"""PyTorch Lightning models modules."""
+from .transformer import LitTransformerModel
+from .vqvae import LitVQVAEModel
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index aeda039..88ffde6 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -40,6 +40,15 @@ class LitBaseModel(pl.LightningModule):
args = {} or criterion.args
return getattr(nn, criterion.type)(**args)
+ def optimizer_zero_grad(
+ self,
+ epoch: int,
+ batch_idx: int,
+ optimizer: Type[torch.optim.Optimizer],
+ optimizer_idx: int,
+ ) -> None:
+ optimizer.zero_grad(set_to_none=True)
+
def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
"""Configures the optimizer."""
args = {} or self._optimizer.args
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
new file mode 100644
index 0000000..ef2213c
--- /dev/null
+++ b/text_recognizer/models/vqvae.py
@@ -0,0 +1,70 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Any, Dict, Union, Tuple, Type
+
+from omegaconf import DictConfig, OmegaConf
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+import wandb
+
+from text_recognizer.models.base import LitBaseModel
+
+
+class LitVQVAEModel(LitBaseModel):
+ """A PyTorch Lightning model for transformer networks."""
+
+ def __init__(
+ self,
+ network: Type[nn.Module],
+ optimizer: Union[DictConfig, Dict],
+ lr_scheduler: Union[DictConfig, Dict],
+ criterion: Union[DictConfig, Dict],
+ monitor: str = "val_loss",
+ *args: Any,
+ **kwargs: Dict,
+ ) -> None:
+ super().__init__(network, optimizer, lr_scheduler, criterion, monitor)
+
+ def forward(self, data: Tensor) -> Tensor:
+ """Forward pass with the transformer network."""
+ return self.network.predict(data)
+
+ def _log_prediction(self, data: Tensor, reconstructions: Tensor) -> None:
+ """Logs prediction on image with wandb."""
+ try:
+ self.logger.experiment.log(
+ {
+ "val_pred_examples": [
+ wandb.Image(data[0]),
+ wandb.Image(reconstructions[0]),
+ ]
+ }
+ )
+ except AttributeError:
+ pass
+
+ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
+ """Training step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self.log("train_loss", loss)
+ return loss
+
+ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Validation step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self.log("val_loss", loss, prog_bar=True)
+ self._log_prediction(data, reconstructions)
+
+ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Test step."""
+ data, _ = batch
+ reconstructions, vq_loss = self.network(data)
+ loss = self.loss_fn(reconstructions, data)
+ loss += vq_loss
+ self._log_prediction(data, reconstructions)