diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:30:29 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-11-21 21:30:29 +0100 |
commit | 61398b45ed6e2501036d4f3e4115a825035b0f91 (patch) | |
tree | 011fc62205391050747fb67c22832e0d5b06e9c6 /text_recognizer | |
parent | c82c1e84497e99ab89e162786c09a642e09bd504 (diff) |
Remove label smoothing loss
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/criterion/label_smoothing.py | 50 |
1 files changed, 0 insertions, 50 deletions
diff --git a/text_recognizer/criterion/label_smoothing.py b/text_recognizer/criterion/label_smoothing.py deleted file mode 100644 index 5b3a47e..0000000 --- a/text_recognizer/criterion/label_smoothing.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor - - -class LabelSmoothingLoss(nn.Module): - r"""Loss functions for making networks less over confident. - - It is used to calibrate the network so that the predicted probabilities - reflect the accuracy. The function is given by: - - L = (1 - \alpha) * y_hot + \alpha / K - - This means that some of the probability mass is transferred to the incorrect - labels, thus not forcing the network try to put all probability mass into - one label, and this works as a regulizer. - """ - - def __init__( - self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 - ) -> None: - super().__init__() - if not 0.0 < smoothing < 1.0: - raise ValueError("Smoothing must be between 0.0 and 1.0") - self.ignore_index = ignore_index - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.dim = dim - - def forward(self, output: Tensor, target: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): outputictions from the network. - target (Tensor): Ground truth. - - Shapes: - TBC - - Returns: - (Tensor): Label smoothing loss. - """ - output = output.log_softmax(dim=self.dim) - with torch.no_grad(): - true_dist = torch.zeros_like(output) - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) - true_dist.masked_fill_((target == 4).unsqueeze(1), 0) - true_dist += self.smoothing / output.size(self.dim) - return torch.mean(torch.sum(-true_dist * output, dim=self.dim)) |