path: root/text_recognizer/networks/loss/label_smoothing_loss.py
"""Implementations of custom loss functions."""
import torch
from torch import nn
from torch import Tensor
import torch.nn.functional as F


class LabelSmoothingLoss(nn.Module):
    """Label smoothing cross entropy loss."""

    def __init__(
        self, label_smoothing: float, vocab_size: int, ignore_index: int = -100
    ) -> None:
        assert 0.0 < label_smoothing <= 1.0
        super().__init__()
        self.ignore_index = ignore_index

        # Spread the smoothing mass evenly over every class except the true
        # label and the ignored (padding) token, hence the (vocab_size - 2).
        smoothing_value = label_smoothing / (vocab_size - 2)
        one_hot = torch.full((vocab_size,), smoothing_value)
        # The ignore index is expected to be a valid class index (e.g. the
        # padding token id) so its smoothing mass can be zeroed out here.
        one_hot[self.ignore_index] = 0
        self.register_buffer("one_hot", one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output: Tensor, targets: Tensor) -> Tensor:
        """Computes the loss.

        Args:
            output (Tensor): Log probabilities from the network.
            targets (Tensor): Ground truth class indices.

        Shapes:
            output: Batch size x num classes
            targets: Batch size

        Returns:
            Tensor: Label smoothing loss.
        """
        # Build the target distribution: smoothing mass everywhere, full
        # confidence at the ground truth index, and all zeros for padded rows.
        model_prob = self.one_hot.repeat(targets.size(0), 1)
        model_prob.scatter_(1, targets.unsqueeze(1), self.confidence)
        model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0)
        # KL divergence between the smoothed target distribution and the
        # network's predictions; F.kl_div expects the input in log space.
        return F.kl_div(output, model_prob, reduction="sum")
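

# Minimal usage sketch (an assumption, not part of the original module): the
# values below are hypothetical and only illustrate the expected call pattern.
# The network output must be log probabilities (e.g. via log_softmax), since
# forward() feeds it straight into F.kl_div.
if __name__ == "__main__":
    vocab_size = 10
    pad_index = 3  # hypothetical padding token id reused as ignore_index
    criterion = LabelSmoothingLoss(
        label_smoothing=0.1, vocab_size=vocab_size, ignore_index=pad_index
    )
    logits = torch.randn(4, vocab_size)  # (batch_size, num_classes)
    targets = torch.tensor([1, pad_index, 5, 7])  # (batch_size,), pad rows ignored
    loss = criterion(F.log_softmax(logits, dim=-1), targets)
    print(loss.item())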