From 61398b45ed6e2501036d4f3e4115a825035b0f91 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 21 Nov 2021 21:30:29 +0100 Subject: Remove label smoothing loss --- text_recognizer/criterion/label_smoothing.py | 50 ---------------------------- 1 file changed, 50 deletions(-) delete mode 100644 text_recognizer/criterion/label_smoothing.py diff --git a/text_recognizer/criterion/label_smoothing.py b/text_recognizer/criterion/label_smoothing.py deleted file mode 100644 index 5b3a47e..0000000 --- a/text_recognizer/criterion/label_smoothing.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor - - -class LabelSmoothingLoss(nn.Module): - r"""Loss functions for making networks less over confident. - - It is used to calibrate the network so that the predicted probabilities - reflect the accuracy. The function is given by: - - L = (1 - \alpha) * y_hot + \alpha / K - - This means that some of the probability mass is transferred to the incorrect - labels, thus not forcing the network try to put all probability mass into - one label, and this works as a regulizer. - """ - - def __init__( - self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 - ) -> None: - super().__init__() - if not 0.0 < smoothing < 1.0: - raise ValueError("Smoothing must be between 0.0 and 1.0") - self.ignore_index = ignore_index - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.dim = dim - - def forward(self, output: Tensor, target: Tensor) -> Tensor: - """Computes the loss. - - Args: - output (Tensor): outputictions from the network. - target (Tensor): Ground truth. - - Shapes: - TBC - - Returns: - (Tensor): Label smoothing loss. - """ - output = output.log_softmax(dim=self.dim) - with torch.no_grad(): - true_dist = torch.zeros_like(output) - true_dist.scatter_(1, target.unsqueeze(1), self.confidence) - true_dist.masked_fill_((target == 4).unsqueeze(1), 0) - true_dist += self.smoothing / output.size(self.dim) - return torch.mean(torch.sum(-true_dist * output, dim=self.dim)) -- cgit v1.2.3-70-g09d2