summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:03:20 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-09-19 21:03:20 +0200
commit9b2ecf296b196432a45eca14300e00b78972e44f (patch)
tree189c7e1dd316de100944a82f0c35fb1b71ffdd1a /text_recognizer/criterions
parent5a5228072ffe015e676ba696dd022145a9f44222 (diff)
Linting of vqgan loss
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r--text_recognizer/criterions/vqgan_loss.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py
index cfd507d..7af1a55 100644
--- a/text_recognizer/criterions/vqgan_loss.py
+++ b/text_recognizer/criterions/vqgan_loss.py
@@ -1,6 +1,5 @@
"""VQGAN loss for PyTorch Lightning."""
-from typing import Dict, Optional
-from click.types import Tuple
+from typing import Optional, Tuple
import torch
from torch import nn, Tensor
@@ -9,9 +8,10 @@ import torch.nn.functional as F
from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator
-def adopt_weight(
+def _adopt_weight(
weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0
) -> float:
+ """Sets weight to value after the threshold is passed."""
if global_step < threshold:
weight = value
return weight
@@ -86,7 +86,7 @@ class VQGANLoss(nn.Module):
else:
d_weight = torch.tensor(0.0)
- d_factor = adopt_weight(
+ d_factor = _adopt_weight(
self.discriminator_factor,
global_step=global_step,
threshold=self.discriminator_iter_start,
@@ -107,7 +107,7 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_real = self.discriminator(data.contiguous().detach())
- d_factor = adopt_weight(
+ d_factor = _adopt_weight(
self.discriminator_factor,
global_step=global_step,
threshold=self.discriminator_iter_start,