path: root/text_recognizer/criterion/label_smoothing.py
"""Implementations of custom loss functions."""
import torch
from torch import nn
from torch import Tensor


class LabelSmoothingLoss(nn.Module):
    r"""Loss functions for making networks less over confident.

    It is used to calibrate the network so that the predicted probabilities
    reflect the accuracy. The function is given by:

        L = (1 - \alpha) * y_hot + \alpha / K

    This means that some of the probability mass is transferred to the incorrect
    labels, thus not forcing the network try to put all probability mass into
    one label, and this works as a regulizer.
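
    For example, with K = 5 classes and \alpha = 0.1, the correct class
    receives probability 0.9 + 0.1 / 5 = 0.92 and each incorrect class
    receives 0.1 / 5 = 0.02.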
    """

    def __init__(
        self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1
    ) -> None:
        super().__init__()
        if not 0.0 <= smoothing < 1.0:
            raise ValueError("Smoothing must be in the range [0.0, 1.0)")
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, output: Tensor, target: Tensor) -> Tensor:
        """Computes the loss.

        Args:
            output (Tensor): Logits from the network.
            target (Tensor): Ground truth.

        Shapes:
            output: (B, K), where B is the batch size and K is the
                number of classes.
            target: (B,), with class indices in [0, K).

        Returns:
            (Tensor): Label smoothing loss.
        """
        output = output.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Build the smoothed target distribution: the true class gets
            # (1 - smoothing) and smoothing / K is spread over all K classes.
            true_dist = torch.zeros_like(output)
            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
            true_dist += self.smoothing / output.size(self.dim)
            # Zero out positions whose target should be ignored, e.g.
            # padding, so that they do not contribute to the loss.
            true_dist.masked_fill_((target == self.ignore_index).unsqueeze(1), 0.0)
        return torch.mean(torch.sum(-true_dist * output, dim=self.dim))
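

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this criterion might be called. It assumes
# logits of shape (batch_size, num_classes) and integer class targets;
# index 0 is arbitrarily chosen here as the padding/ignore label.
if __name__ == "__main__":
    criterion = LabelSmoothingLoss(ignore_index=0, smoothing=0.1)
    logits = torch.randn(3, 5)         # (B=3, K=5) raw network outputs
    targets = torch.tensor([2, 4, 0])  # last entry is padding and is masked out
    loss = criterion(logits, targets)
    print(f"label smoothing loss: {loss.item():.4f}")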