summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions/label_smoothing.py
blob: cc71c45cc6102d2b9f8ae326a6cbe0222aa72b8e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
"""Implementations of custom loss functions."""
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """Cross-entropy loss against a label-smoothed target distribution.

    The target distribution puts ``1 - smoothing + smoothing / C`` on the true
    class and ``smoothing / C`` on every other class, where ``C`` is the size of
    the class dimension of the network output. Positions whose target equals
    ``ignore_index`` contribute nothing to the loss and are excluded from the
    average.

    Args:
        ignore_index (int): Target value whose positions are ignored.
            Defaults to -100.
        smoothing (float): Smoothing factor in ``[0, 1]``. Defaults to 0.0,
            which reduces to plain cross-entropy.
        dim (int): Class dimension of the network output. Defaults to -1.

    Raises:
        ValueError: If ``smoothing`` is outside ``[0, 1]``.
    """

    def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1) -> None:
        super().__init__()
        # Raise instead of ``assert`` so the check survives ``python -O``;
        # 0.0 is accepted so that the default arguments are valid.
        if not 0.0 <= smoothing <= 1.0:
            raise ValueError(f"smoothing must be in [0, 1], got {smoothing}")
        self.ignore_index = ignore_index
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.dim = dim

    def forward(self, output: Tensor, target: Tensor) -> Tensor:
        """Computes the loss.

        Args:
            output (Tensor): Unnormalized predictions (logits) from the network.
            target (Tensor): Ground-truth class indices (integer dtype).

        Shapes:
            output: any shape with the classes along ``self.dim``.
            target: the shape of ``output`` with ``self.dim`` removed.

        Returns:
            Tensor: Scalar label smoothing loss, averaged over the non-ignored
            positions (0 if every position is ignored).
        """
        log_probs = output.log_softmax(dim=self.dim)
        num_classes = output.size(self.dim)
        with torch.no_grad():
            ignore_mask = target == self.ignore_index
            # ignore_index (e.g. -100) need not be a valid class index, so point
            # ignored positions at class 0 before scattering; they are zeroed below.
            safe_target = target.masked_fill(ignore_mask, 0).unsqueeze(self.dim)
            true_dist = torch.zeros_like(log_probs)
            true_dist.scatter_(self.dim, safe_target, self.confidence)
            true_dist += self.smoothing / num_classes
            # Zero the whole distribution (including the smoothing mass) at
            # ignored positions so they contribute no loss.
            true_dist.masked_fill_(ignore_mask.unsqueeze(self.dim), 0.0)
        per_position = torch.sum(-true_dist * log_probs, dim=self.dim)
        # Average over non-ignored positions only (as nn.CrossEntropyLoss does);
        # the clamp avoids 0/0 when every target is ignored.
        num_valid = (~ignore_mask).sum().clamp(min=1)
        return per_position.sum() / num_valid