Diffstat (limited to 'text_recognizer/models')
-rw-r--r--  text_recognizer/models/base.py          81
-rw-r--r--  text_recognizer/models/transformer.py     4
-rw-r--r--  text_recognizer/models/vqgan.py         134
-rw-r--r--  text_recognizer/models/vqvae.py          12
4 files changed, 192 insertions, 39 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index ab3fa35..8b68ed9 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -24,8 +24,8 @@ class BaseLitModel(LightningModule):
network: Type[nn.Module] = attr.ib()
mapping: Type[AbstractMapping] = attr.ib()
loss_fn: Type[nn.Module] = attr.ib()
- optimizer_config: DictConfig = attr.ib()
- lr_scheduler_config: DictConfig = attr.ib()
+ optimizer_configs: DictConfig = attr.ib()
+ lr_scheduler_configs: DictConfig = attr.ib()
train_acc: torchmetrics.Accuracy = attr.ib(
init=False, default=torchmetrics.Accuracy()
)
@@ -45,40 +45,55 @@ class BaseLitModel(LightningModule):
) -> None:
optimizer.zero_grad(set_to_none=True)
- def _configure_optimizer(self) -> Type[torch.optim.Optimizer]:
+    def _configure_optimizers(self) -> List[Type[torch.optim.Optimizer]]:
"""Configures the optimizer."""
- log.info(f"Instantiating optimizer <{self.optimizer_config._target_}>")
- return hydra.utils.instantiate(
- self.optimizer_config, params=self.network.parameters()
- )
-
- def _configure_lr_scheduler(
- self, optimizer: Type[torch.optim.Optimizer]
- ) -> Dict[str, Any]:
+ optimizers = []
+ for optimizer_config in self.optimizer_configs.values():
+ network = getattr(self, optimizer_config.parameters)
+ del optimizer_config.parameters
+ log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
+ optimizers.append(
+ hydra.utils.instantiate(
+                    optimizer_config, params=network.parameters()
+ )
+ )
+ return optimizers
+
+ def _configure_lr_schedulers(
+ self, optimizers: List[Type[torch.optim.Optimizer]]
+ ) -> List[Dict[str, Any]]:
"""Configures the lr scheduler."""
- # Extract non-class arguments.
- monitor = self.lr_scheduler_config.monitor
- interval = self.lr_scheduler_config.interval
- del self.lr_scheduler_config.monitor
- del self.lr_scheduler_config.interval
-
- log.info(
- f"Instantiating learning rate scheduler <{self.lr_scheduler_config._target_}>"
- )
- scheduler = {
- "monitor": monitor,
- "interval": interval,
- "scheduler": hydra.utils.instantiate(
- self.lr_scheduler_config, optimizer=optimizer
- ),
- }
- return scheduler
-
- def configure_optimizers(self) -> Tuple[List[type], List[Dict[str, Any]]]:
+ schedulers = []
+ for optimizer, lr_scheduler_config in zip(
+ optimizers, self.lr_scheduler_configs.values()
+ ):
+ # Extract non-class arguments.
+ monitor = lr_scheduler_config.monitor
+ interval = lr_scheduler_config.interval
+ del lr_scheduler_config.monitor
+ del lr_scheduler_config.interval
+
+ log.info(
+ f"Instantiating learning rate scheduler <{lr_scheduler_config._target_}>"
+ )
+ scheduler = {
+ "monitor": monitor,
+ "interval": interval,
+ "scheduler": hydra.utils.instantiate(
+ lr_scheduler_config, optimizer=optimizer
+ ),
+ }
+ schedulers.append(scheduler)
+
+ return schedulers
+
+ def configure_optimizers(
+ self,
+ ) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
"""Configures optimizer and lr scheduler."""
- optimizer = self._configure_optimizer()
- scheduler = self._configure_lr_scheduler(optimizer)
- return [optimizer], [scheduler]
+        optimizers = self._configure_optimizers()
+        schedulers = self._configure_lr_schedulers(optimizers)
+ return optimizers, schedulers
def forward(self, data: Tensor) -> Tensor:
"""Feedforward pass."""
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 5fb84a7..75f7523 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -60,6 +60,8 @@ class TransformerLitModel(BaseLitModel):
pred = self(data)
self.val_cer(pred, targets)
self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
+        self.val_acc(pred, targets)
+        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
@@ -69,6 +71,8 @@ class TransformerLitModel(BaseLitModel):
pred = self(data)
self.test_cer(pred, targets)
self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+ self.test_acc(pred, targets)
+ self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
def predict(self, x: Tensor) -> Tensor:
"""Predicts text in image.
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
new file mode 100644
index 0000000..8ff65cc
--- /dev/null
+++ b/text_recognizer/models/vqgan.py
@@ -0,0 +1,134 @@
+"""PyTorch Lightning model for base Transformers."""
+from typing import Tuple
+
+import attr
+from torch import Tensor
+
+from text_recognizer.models.base import BaseLitModel
+from text_recognizer.criterions.vqgan_loss import VQGANLoss
+
+
+@attr.s(auto_attribs=True, eq=False)
+class VQGANLitModel(BaseLitModel):
+    """A PyTorch Lightning model for the VQGAN."""
+
+ loss_fn: VQGANLoss = attr.ib()
+ latent_loss_weight: float = attr.ib(default=0.25)
+
+ def forward(self, data: Tensor) -> Tensor:
+        """Forward pass with the VQGAN network."""
+ return self.network(data)
+
+ def training_step(
+ self, batch: Tuple[Tensor, Tensor], batch_idx: int, optimizer_idx: int
+ ) -> Tensor:
+ """Training step."""
+ data, _ = batch
+
+ reconstructions, vq_loss = self(data)
+
+ if optimizer_idx == 0:
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=optimizer_idx,
+ stage="train",
+ )
+ self.log(
+ "train/loss",
+ loss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ if optimizer_idx == 1:
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=optimizer_idx,
+ stage="train",
+ )
+ self.log(
+ "train/discriminator_loss",
+ loss,
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log, prog_bar=False, logger=True, on_step=True, on_epoch=True)
+ return loss
+
+ def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Validation step."""
+ data, _ = batch
+ reconstructions, vq_loss = self(data)
+
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=0,
+ stage="val",
+ )
+ self.log(
+ "val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+ )
+ self.log(
+ "val/rec_loss",
+ log["val/rec_loss"],
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log)
+
+ _, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=1,
+ stage="val",
+ )
+ self.log_dict(log)
+
+ def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
+ """Test step."""
+ data, _ = batch
+ reconstructions, vq_loss = self(data)
+
+ loss, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=0,
+ stage="test",
+ )
+ self.log(
+ "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
+ )
+ self.log(
+ "test/rec_loss",
+ log["test/rec_loss"],
+ prog_bar=True,
+ logger=True,
+ on_step=True,
+ on_epoch=True,
+ )
+ self.log_dict(log)
+
+ _, log = self.loss_fn(
+ data=data,
+ reconstructions=reconstructions,
+ vq_loss=vq_loss,
+ optimizer_idx=1,
+ stage="test",
+ )
+ self.log_dict(log)
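
The VQGANLoss criterion is imported but defined outside this diff. From the call sites above one can infer its interface: it takes the raw batch, the reconstructions, the vector-quantization loss, the active optimizer index, and a stage tag, and returns a (loss, log_dict) pair, routing between autoencoder and discriminator objectives on optimizer_idx (0 and 1, matching the two optimizers Lightning supplies to training_step) in the style of taming-transformers. A hedged stub of that inferred interface; the internals are illustrative, not the real implementation:

from typing import Dict, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class VQGANLossSketch(nn.Module):
    """Illustrative stand-in for the inferred VQGANLoss interface."""

    def forward(
        self,
        data: Tensor,
        reconstructions: Tensor,
        vq_loss: Tensor,
        optimizer_idx: int,
        stage: str,
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        rec_loss = F.l1_loss(reconstructions, data)
        if optimizer_idx == 0:
            # Autoencoder update: reconstruction plus commitment terms (a
            # real VQGAN loss adds perceptual and adversarial components).
            loss = rec_loss + vq_loss
            return loss, {f"{stage}/rec_loss": rec_loss}
        # Discriminator update: would score real vs. reconstructed images;
        # zeroed here purely to keep the sketch self-contained.
        d_loss = torch.zeros_like(rec_loss)
        return d_loss, {f"{stage}/discriminator_loss": d_loss}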
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index ef9a59a..56229b3 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -28,8 +28,8 @@ class VQVAELitModel(BaseLitModel):
self.log("train/vq_loss", vq_loss)
self.log("train/loss", loss)
- self.train_acc(reconstructions, data)
- self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
+ # self.train_acc(reconstructions, data)
+ # self.log("train/acc", self.train_acc, on_step=False, on_epoch=True)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
@@ -42,8 +42,8 @@ class VQVAELitModel(BaseLitModel):
self.log("val/vq_loss", vq_loss)
self.log("val/loss", loss, prog_bar=True)
- self.val_acc(reconstructions, data)
- self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
+ # self.val_acc(reconstructions, data)
+ # self.log("val/acc", self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
@@ -53,5 +53,5 @@ class VQVAELitModel(BaseLitModel):
loss = loss + self.latent_loss_weight * vq_loss
self.log("test/vq_loss", vq_loss)
self.log("test/loss", loss)
- self.test_acc(reconstructions, data)
- self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
+ # self.test_acc(reconstructions, data)
+ # self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
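
The accuracy calls in vqvae.py are commented out rather than fixed, presumably because torchmetrics.Accuracy expects class labels and rejects the continuous tensors a reconstruction produces. If a reconstruction-quality metric is wanted here, torchmetrics.MeanSquaredError fits the same update/log pattern; a sketch, not part of this commit:

import torch
import torchmetrics

# MeanSquaredError accepts continuous tensors, so it drops in where the
# commented-out Accuracy calls could not. Shapes below are illustrative.
mse = torchmetrics.MeanSquaredError()
reconstructions = torch.rand(8, 1, 28, 28)
data = torch.rand(8, 1, 28, 28)
mse(reconstructions, data)  # same call shape as the commented-out metric
print(mse.compute())        # MSE aggregated over all batches seen so far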