summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r--text_recognizer/criterions/vqgan_loss.py53
1 files changed, 49 insertions, 4 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py
index 87f0f1c..cfd507d 100644
--- a/text_recognizer/criterions/vqgan_loss.py
+++ b/text_recognizer/criterions/vqgan_loss.py
@@ -9,6 +9,14 @@ import torch.nn.functional as F
from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator
+def adopt_weight(
+ weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0
+) -> float:
+ if global_step < threshold:
+ weight = value
+ return weight
+
+
class VQGANLoss(nn.Module):
"""VQGAN loss."""
@@ -18,12 +26,16 @@ class VQGANLoss(nn.Module):
discriminator: NLayerDiscriminator,
vq_loss_weight: float = 1.0,
discriminator_weight: float = 1.0,
+ discriminator_factor: float = 1.0,
+ discriminator_iter_start: int = 1000,
) -> None:
super().__init__()
self.reconstruction_loss = reconstruction_loss
self.discriminator = discriminator
self.vq_loss_weight = vq_loss_weight
self.discriminator_weight = discriminator_weight
+ self.discriminator_factor = discriminator_factor
+ self.discriminator_iter_start = discriminator_iter_start
@staticmethod
def adversarial_loss(logits_real: Tensor, logits_fake: Tensor) -> Tensor:
@@ -33,12 +45,26 @@ class VQGANLoss(nn.Module):
d_loss = (loss_real + loss_fake) / 2.0
return d_loss
+ def _adaptive_weight(
+ self, rec_loss: Tensor, g_loss: Tensor, decoder_last_layer: Tensor
+ ) -> Tensor:
+ rec_grads = torch.autograd.grad(
+ rec_loss, decoder_last_layer, retain_graph=True
+ )[0]
+ g_grads = torch.autograd.grad(g_loss, decoder_last_layer, retain_graph=True)[0]
+ d_weight = torch.norm(rec_grads) / (torch.norm(g_grads) + 1.0e-4)
+ d_weight = torch.clamp(d_weight, 0.0, 1.0e4).detach()
+ d_weight *= self.discriminator_weight
+ return d_weight
+
def forward(
self,
data: Tensor,
reconstructions: Tensor,
vq_loss: Tensor,
+ decoder_last_layer: Tensor,
optimizer_idx: int,
+ global_step: int,
stage: str,
) -> Optional[Tuple]:
"""Calculates the VQGAN loss."""
@@ -51,10 +77,23 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous())
g_loss = -torch.mean(logits_fake)
+ if self.training:
+ d_weight = self._adaptive_weight(
+ rec_loss=rec_loss,
+ g_loss=g_loss,
+ decoder_last_layer=decoder_last_layer,
+ )
+ else:
+ d_weight = torch.tensor(0.0)
+
+ d_factor = adopt_weight(
+ self.discriminator_factor,
+ global_step=global_step,
+ threshold=self.discriminator_iter_start,
+ )
+
loss: Tensor = (
- rec_loss
- + self.discriminator_weight * g_loss
- + self.vq_loss_weight * vq_loss
+ rec_loss + d_factor * d_weight * g_loss + self.vq_loss_weight * vq_loss
)
log = {
f"{stage}/total_loss": loss,
@@ -68,7 +107,13 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_real = self.discriminator(data.contiguous().detach())
- d_loss = self.discriminator_weight * self.adversarial_loss(
+ d_factor = adopt_weight(
+ self.discriminator_factor,
+ global_step=global_step,
+ threshold=self.discriminator_iter_start,
+ )
+
+ d_loss = d_factor * self.adversarial_loss(
logits_real=logits_real, logits_fake=logits_fake
)