author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:41:39 +0200
---|---|---
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:41:39 +0200
commit | bec4aafe707be8e5763ad6b2194d4589f20594a9 (patch) |
tree | 506517ca6a17241a305114e787d1b899a48a3d86 /text_recognizer/criterion/label_smoothing.py |
parent | 9a8044f4a3826a119416665741b709cd686fca87 (diff) |
Rename to criterion
Diffstat (limited to 'text_recognizer/criterion/label_smoothing.py')
-rw-r--r-- | text_recognizer/criterion/label_smoothing.py | 50
1 file changed, 50 insertions, 0 deletions
diff --git a/text_recognizer/criterion/label_smoothing.py b/text_recognizer/criterion/label_smoothing.py
new file mode 100644
index 0000000..5b3a47e
--- /dev/null
+++ b/text_recognizer/criterion/label_smoothing.py
@@ -0,0 +1,50 @@
+"""Implementations of custom loss functions."""
+import torch
+from torch import nn
+from torch import Tensor
+
+
+class LabelSmoothingLoss(nn.Module):
+    r"""Loss function for making networks less overconfident.
+
+    It is used to calibrate the network so that the predicted probabilities
+    reflect the accuracy. The smoothed target distribution is given by:
+
+        y_smooth = (1 - \alpha) * y_hot + \alpha / K
+
+    This means that some of the probability mass is transferred to the
+    incorrect labels, so the network is not forced to put all probability
+    mass into one label, which acts as a regularizer.
+    """
+
+    def __init__(
+        self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1
+    ) -> None:
+        super().__init__()
+        if not 0.0 <= smoothing < 1.0:
+            raise ValueError("Smoothing must be in the range [0.0, 1.0)")
+        self.ignore_index = ignore_index
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+        self.dim = dim
+
+    def forward(self, output: Tensor, target: Tensor) -> Tensor:
+        """Computes the loss.
+
+        Args:
+            output (Tensor): Predictions from the network.
+            target (Tensor): Ground truth labels.
+
+        Shapes:
+            output: (B, num_classes), target: (B)
+
+        Returns:
+            (Tensor): Label smoothing loss.
+        """
+        output = output.log_softmax(dim=self.dim)
+        with torch.no_grad():
+            true_dist = torch.zeros_like(output)
+            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+            true_dist += self.smoothing / output.size(self.dim)
+            true_dist.masked_fill_((target == self.ignore_index).unsqueeze(1), 0.0)
+        return torch.mean(torch.sum(-true_dist * output, dim=self.dim))
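For reference, a minimal usage sketch of the class added above; the batch size, class count, and ignore index here are illustrative assumptions, not values taken from the repository:

```python
import torch

from text_recognizer.criterion.label_smoothing import LabelSmoothingLoss

# Hypothetical sizes: batch of 8, 58-class vocabulary, ignore_index of 3.
criterion = LabelSmoothingLoss(ignore_index=3, smoothing=0.1)

logits = torch.randn(8, 58, requires_grad=True)  # network outputs, (B, num_classes)
targets = torch.randint(0, 58, (8,))             # ground-truth indices, (B)
targets[0] = 3                                   # padded position: zeroed out of the loss

loss = criterion(logits, targets)                # scalar tensor
loss.backward()
```

With `smoothing=0.1`, the true class receives probability 0.9 plus a 0.1/58 share, every other class receives 0.1/58, and rows whose target equals `ignore_index` contribute nothing to the sum.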