summary | refs | log | tree | commit | diff
path: root/text_recognizer
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-08-08 21:43:39 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-08-08 21:43:39 +0200
commit 82f4acabe24e5171c40afa2939a4777ba87bcc30 (patch)
tree 4d327fa26e4662a0447a66375442a9adeb13ea3d /text_recognizer
parent 240f5e9f20032e82515fa66ce784619527d1041e (diff)
Add training of VQGAN
Diffstat (limited to 'text_recognizer')
-rw-r--r-- text_recognizer/criterions/vqgan_loss.py 23
-rw-r--r-- text_recognizer/models/base.py 8
-rw-r--r-- text_recognizer/models/vqgan.py 24
3 files changed, 15 insertions, 40 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py
index 8bb568f..87f0f1c 100644
--- a/text_recognizer/criterions/vqgan_loss.py
+++ b/text_recognizer/criterions/vqgan_loss.py
@@ -1,5 +1,5 @@
"""VQGAN loss for PyTorch Lightning."""
-from typing import Dict
+from typing import Dict, Optional
from click.types import Tuple
import torch
@@ -40,9 +40,9 @@ class VQGANLoss(nn.Module):
vq_loss: Tensor,
optimizer_idx: int,
stage: str,
- ) -> Tuple[Tensor, Dict[str, Tensor]]:
+ ) -> Optional[Tuple]:
"""Calculates the VQGAN loss."""
- rec_loss = self.reconstruction_loss(
+ rec_loss: Tensor = self.reconstruction_loss(
data.contiguous(), reconstructions.contiguous()
)
@@ -51,13 +51,13 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
- loss = (
+ loss: Tensor = (
rec_loss
+ self.discriminator_weight * g_loss
+ self.vq_loss_weight * vq_loss
)
log = {
- f"{stage}/loss": loss,
+ f"{stage}/total_loss": loss,
f"{stage}/vq_loss": vq_loss,
f"{stage}/rec_loss": rec_loss,
f"{stage}/g_loss": g_loss,
@@ -68,18 +68,11 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_real = self.discriminator(data.contiguous().detach())
- d_loss = self.adversarial_loss(
+ d_loss = self.discriminator_weight * self.adversarial_loss(
logits_real=logits_real, logits_fake=logits_fake
)
- loss = (
- rec_loss
- + self.discriminator_weight * d_loss
- + self.vq_loss_weight * vq_loss
- )
+
log = {
- f"{stage}/loss": loss,
- f"{stage}/vq_loss": vq_loss,
- f"{stage}/rec_loss": rec_loss,
f"{stage}/d_loss": d_loss,
}
- return loss, log
+ return d_loss, log
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 8b68ed9..94dbde5 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -49,12 +49,14 @@ class BaseLitModel(LightningModule):
"""Configures the optimizer."""
optimizers = []
for optimizer_config in self.optimizer_configs.values():
- network = getattr(self, optimizer_config.parameters)
+ module = self
+ for m in str(optimizer_config.parameters).split("."):
+ module = getattr(module, m)
del optimizer_config.parameters
log.info(f"Instantiating optimizer <{optimizer_config._target_}>")
optimizers.append(
hydra.utils.instantiate(
- self.optimizer_config, params=network.parameters()
+ optimizer_config, params=module.parameters()
)
)
return optimizers
@@ -92,7 +94,7 @@ class BaseLitModel(LightningModule):
) -> Tuple[List[Type[torch.optim.Optimizer]], List[Dict[str, Any]]]:
"""Configures optimizer and lr scheduler."""
optimizers = self._configure_optimizer()
- schedulers = self._configure_lr_scheduler(optimizers)
+ schedulers = self._configure_lr_schedulers(optimizers)
return optimizers, schedulers
def forward(self, data: Tensor) -> Tensor:
diff --git a/text_recognizer/models/vqgan.py b/text_recognizer/models/vqgan.py
index 8ff65cc..80653b6 100644
--- a/text_recognizer/models/vqgan.py
+++ b/text_recognizer/models/vqgan.py
@@ -9,7 +9,7 @@ from text_recognizer.criterions.vqgan_loss import VQGANLoss
@attr.s(auto_attribs=True, eq=False)
-class VQVAELitModel(BaseLitModel):
+class VQGANLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
loss_fn: VQGANLoss = attr.ib()
@@ -26,7 +26,6 @@ class VQVAELitModel(BaseLitModel):
data, _ = batch
reconstructions, vq_loss = self(data)
- loss = self.loss_fn(reconstructions, data)
if optimizer_idx == 0:
loss, log = self.loss_fn(
@@ -81,14 +80,6 @@ class VQVAELitModel(BaseLitModel):
self.log(
"val/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
)
- self.log(
- "val/rec_loss",
- log["val/rec_loss"],
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
- )
self.log_dict(log)
_, log = self.loss_fn(
@@ -105,24 +96,13 @@ class VQVAELitModel(BaseLitModel):
data, _ = batch
reconstructions, vq_loss = self(data)
- loss, log = self.loss_fn(
+ _, log = self.loss_fn(
data=data,
reconstructions=reconstructions,
vq_loss=vq_loss,
optimizer_idx=0,
stage="test",
)
- self.log(
- "test/loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
- )
- self.log(
- "test/rec_loss",
- log["test/rec_loss"],
- prog_bar=True,
- logger=True,
- on_step=True,
- on_epoch=True,
- )
self.log_dict(log)
_, log = self.loss_fn(