From 9b2ecf296b196432a45eca14300e00b78972e44f Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 19 Sep 2021 21:03:20 +0200 Subject: Linting of vqgan loss --- text_recognizer/criterions/vqgan_loss.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py index cfd507d..7af1a55 100644 --- a/text_recognizer/criterions/vqgan_loss.py +++ b/text_recognizer/criterions/vqgan_loss.py @@ -1,6 +1,5 @@ """VQGAN loss for PyTorch Lightning.""" -from typing import Dict, Optional -from click.types import Tuple +from typing import Optional, Tuple import torch from torch import nn, Tensor @@ -9,9 +8,10 @@ import torch.nn.functional as F from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator -def adopt_weight( +def _adopt_weight( weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0 ) -> float: + """Sets weight to value after the threshold is passed.""" if global_step < threshold: weight = value return weight @@ -86,7 +86,7 @@ class VQGANLoss(nn.Module): else: d_weight = torch.tensor(0.0) - d_factor = adopt_weight( + d_factor = _adopt_weight( self.discriminator_factor, global_step=global_step, threshold=self.discriminator_iter_start, @@ -107,7 +107,7 @@ class VQGANLoss(nn.Module): logits_fake = self.discriminator(reconstructions.contiguous().detach()) logits_real = self.discriminator(data.contiguous().detach()) - d_factor = adopt_weight( + d_factor = _adopt_weight( self.discriminator_factor, global_step=global_step, threshold=self.discriminator_iter_start, -- cgit v1.2.3-70-g09d2