summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/criterions/vqgan_loss.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/text_recognizer/criterions/vqgan_loss.py b/text_recognizer/criterions/vqgan_loss.py
index cfd507d..7af1a55 100644
--- a/text_recognizer/criterions/vqgan_loss.py
+++ b/text_recognizer/criterions/vqgan_loss.py
@@ -1,6 +1,5 @@
"""VQGAN loss for PyTorch Lightning."""
-from typing import Dict, Optional
-from click.types import Tuple
+from typing import Optional, Tuple
import torch
from torch import nn, Tensor
@@ -9,9 +8,10 @@ import torch.nn.functional as F
from text_recognizer.criterions.n_layer_discriminator import NLayerDiscriminator
-def adopt_weight(
+def _adopt_weight(
weight: Tensor, global_step: int, threshold: int = 0, value: float = 0.0
) -> float:
+ """Sets weight to value after the threshold is passed."""
if global_step < threshold:
weight = value
return weight
@@ -86,7 +86,7 @@ class VQGANLoss(nn.Module):
else:
d_weight = torch.tensor(0.0)
- d_factor = adopt_weight(
+ d_factor = _adopt_weight(
self.discriminator_factor,
global_step=global_step,
threshold=self.discriminator_iter_start,
@@ -107,7 +107,7 @@ class VQGANLoss(nn.Module):
logits_fake = self.discriminator(reconstructions.contiguous().detach())
logits_real = self.discriminator(data.contiguous().detach())
- d_factor = adopt_weight(
+ d_factor = _adopt_weight(
self.discriminator_factor,
global_step=global_step,
threshold=self.discriminator_iter_start,