summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions/label_smoothing.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/criterions/label_smoothing.py')
-rw-r--r--text_recognizer/criterions/label_smoothing.py38
1 files changed, 16 insertions, 22 deletions
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py
index 40a7609..cc71c45 100644
--- a/text_recognizer/criterions/label_smoothing.py
+++ b/text_recognizer/criterions/label_smoothing.py
@@ -6,37 +6,31 @@ import torch.nn.functional as F
class LabelSmoothingLoss(nn.Module):
- """Label smoothing cross entropy loss."""
-
- def __init__(
- self, label_smoothing: float, vocab_size: int, ignore_index: int = -100
- ) -> None:
- assert 0.0 < label_smoothing <= 1.0
- self.ignore_index = ignore_index
+ def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1):
super().__init__()
+ assert 0.0 < smoothing <= 1.0
+ self.ignore_index = ignore_index
+ self.confidence = 1.0 - smoothing
+ self.smoothing = smoothing
+ self.dim = dim
- smoothing_value = label_smoothing / (vocab_size - 2)
- one_hot = torch.full((vocab_size,), smoothing_value)
- one_hot[self.ignore_index] = 0
- self.register_buffer("one_hot", one_hot.unsqueeze(0))
-
- self.confidence = 1.0 - label_smoothing
-
- def forward(self, output: Tensor, targets: Tensor) -> Tensor:
+ def forward(self, output: Tensor, target: Tensor) -> Tensor:
"""Computes the loss.
Args:
- output (Tensor): Predictions from the network.
+ output (Tensor): outputictions from the network.
targets (Tensor): Ground truth.
Shapes:
- outpus: Batch size x num classes
- targets: Batch size
+ TBC
Returns:
Tensor: Label smoothing loss.
"""
- model_prob = self.one_hot.repeat(targets.size(0), 1)
- model_prob.scatter_(1, targets.unsqueeze(1), self.confidence)
- model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0)
- return F.kl_div(output, model_prob, reduction="sum")
+ output = output.log_softmax(dim=self.dim)
+ with torch.no_grad():
+ true_dist = torch.zeros_like(output)
+ true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+ true_dist.masked_fill_((target == 4).unsqueeze(1), 0)
+ true_dist += self.smoothing / output.size(self.dim)
+ return torch.mean(torch.sum(-true_dist * output, dim=self.dim))