diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-19 21:03:20 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-19 21:03:20 +0200 |
commit | 9b2ecf296b196432a45eca14300e00b78972e44f (patch) | |
tree | 189c7e1dd316de100944a82f0c35fb1b71ffdd1a | |
parent | 5a5228072ffe015e676ba696dd022145a9f44222 (diff) |
Linting of vqgan loss
-rw-r--r-- | text_recognizer/criterions/vqgan_loss.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index cfd507d..7af1a55 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -1,6 +1,5 @@ """VQGAN loss for PyTorch Lightning.""" -from typing import Dict, Optional -from click.types import Tuple +from typing import Optional, Tuple import torch from torch import nn, Tensor @@ -9,9 +8,10 @@ import torch.nn.functional as F from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator -def adopt_weight( +def _adopt_weight( weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 ) -> float: + """Sets weight to value after the threshold is passed.""" if global_step < threshold: weight = value return weight @@ -86,7 +86,7 @@ class VQGANLoss(nn.Module): else: d_weight = torch.tensor(0.0) - d_factor = adopt_weight( + d_factor = _adopt_weight( self.discriminator_factor, global_step=global_step, threshold=self.discriminator_iter_start, @@ -107,7 +107,7 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_real = self.discriminator(data.contiguous().detach()) - d_factor = adopt_weight( + d_factor = _adopt_weight( self.discriminator_factor, global_step=global_step, threshold=self.discriminator_iter_start, |