path: root/text_recognizer/criterion/label_smoothing.py
author    Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-10-27 22:41:39 +0200
committer Gustaf Rydholm <gustaf.rydholm@gmail.com>    2021-10-27 22:41:39 +0200
commit    bec4aafe707be8e5763ad6b2194d4589f20594a9 (patch)
tree      506517ca6a17241a305114e787d1b899a48a3d86 /text_recognizer/criterion/label_smoothing.py
parent    9a8044f4a3826a119416665741b709cd686fca87 (diff)
Rename to criterion
Diffstat (limited to 'text_recognizer/criterion/label_smoothing.py')
-rw-r--r--  text_recognizer/criterion/label_smoothing.py  |  50
1 file changed, 50 insertions, 0 deletions
diff --git a/text_recognizer/criterion/label_smoothing.py b/text_recognizer/criterion/label_smoothing.py
new file mode 100644
index 0000000..5b3a47e
--- /dev/null
+++ b/text_recognizer/criterion/label_smoothing.py
@@ -0,0 +1,50 @@
+"""Implementations of custom loss functions."""
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class LabelSmoothingLoss(nn.Module):
+ r"""Loss functions for making networks less over confident.
+
+ It is used to calibrate the network so that the predicted probabilities
+ reflect the accuracy. The function is given by:
+
+ L = (1 - \alpha) * y_hot + \alpha / K
+
+ This means that some of the probability mass is transferred to the incorrect
+ labels, thus not forcing the network try to put all probability mass into
+ one label, and this works as a regulizer.
+ """
+
+    def __init__(
+        self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1
+    ) -> None:
+        super().__init__()
+        if not 0.0 <= smoothing < 1.0:
+            raise ValueError("Smoothing must be in the range [0.0, 1.0)")
+        self.ignore_index = ignore_index
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+        self.dim = dim
+
+    def forward(self, output: Tensor, target: Tensor) -> Tensor:
+        """Computes the loss.
+
+        Args:
+            output (Tensor): Predictions (logits) from the network.
+            target (Tensor): Ground truth class indices.
+
+        Shapes:
+            output: (B, K), where K is the number of classes; target: (B,).
+
+        Returns:
+            (Tensor): Label smoothing loss.
+        """
+        output = output.log_softmax(dim=self.dim)
+        with torch.no_grad():
+            true_dist = torch.zeros_like(output)
+            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+            true_dist += self.smoothing / output.size(self.dim)
+            true_dist.masked_fill_((target == self.ignore_index).unsqueeze(1), 0)
+        return torch.mean(torch.sum(-true_dist * output, dim=self.dim))
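
Not part of the commit, but as a quick sanity check of the new criterion, here is a minimal usage sketch. The import path follows the file location in this diff; the shapes and tensor values are made up for illustration, and a 2D (batch, num_classes) output is assumed, as implied by the scatter_ on dim 1.

    import torch

    from text_recognizer.criterion.label_smoothing import LabelSmoothingLoss

    # Hypothetical batch of 3 predictions (logits) over K = 5 classes.
    logits = torch.randn(3, 5)
    targets = torch.tensor([0, 2, 1])  # ground-truth class indices, shape (B,)

    criterion = LabelSmoothingLoss(ignore_index=-100, smoothing=0.1)
    loss = criterion(logits, targets)
    print(loss)  # scalar tensor

    # With smoothing = 0.1 and K = 5, the smoothed target row for class 2 is
    # (1 - 0.1) * one_hot + 0.1 / 5 = [0.02, 0.02, 0.92, 0.02, 0.02] (sums to 1).

Rows whose target equals ignore_index end up with an all-zero target distribution, so they contribute nothing to the summed loss, although they still count toward the batch mean.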