diff options
-rw-r--r-- | text_recognizer/criterions/vqgan_loss.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index cfd507d..7af1a55 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -1,6 +1,5 @@ """VQGAN loss for PyTorch Lightning.""" -from typing import Dict, Optional -from click.types import Tuple +from typing import Optional, Tuple import torch from torch import nn, Tensor @@ -9,9 +8,10 @@ import torch.nn.functional as F from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator -def adopt_weight( +def _adopt_weight( weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 ) -> float: + """Sets weight to value after the threshold is passed.""" if global_step < threshold: weight = value return weight @@ -86,7 +86,7 @@ class VQGANLoss(nn.Module): else: d_weight = torch.tensor(0.0) - d_factor = adopt_weight( + d_factor = _adopt_weight( self.discriminator_factor, global_step=global_step, threshold=self.discriminator_iter_start, @@ -107,7 +107,7 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_real = self.discriminator(data.contiguous().detach()) - d_factor = adopt_weight( + d_factor = _adopt_weight( self.discriminator_factor, global_step=global_step, threshold=self.discriminator_iter_start, |