summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/loss/loss.py
blob: d12dc9c0105cbd8110cae2874f8ebdb62dfd963d (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Implementations of custom loss functions."""
from typing import Optional

import torch
from torch import nn
from torch import Tensor

# Public names exported by ``from <this module> import *``.
__all__ = ["LabelSmoothingCrossEntropy"]


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against a label-smoothed target distribution.

    The one-hot target is replaced by a distribution that assigns
    ``1 - smoothing`` to the true class and ``smoothing / (classes - 1)`` to
    every other class. The loss is the cross-entropy between that
    distribution and ``log_softmax(pred)`` taken along ``dim``.

    Args:
        classes: Number of classes (the size of ``pred`` along ``dim``).
            Must be greater than 1, otherwise the smoothing term divides by
            zero.
        smoothing: Probability mass taken off the true class. ``0.0`` yields
            plain cross-entropy.
        ignore_index: Optional target value to skip. Positions whose target
            equals it contribute zero loss. If it is also a valid class
            index, that class receives no smoothing mass. Values outside
            ``[0, classes)`` (e.g. ``-100``, the ``nn.CrossEntropyLoss``
            convention) are used only for masking positions.
        dim: Class dimension of ``pred``.
    """

    def __init__(
        self,
        classes: int,
        smoothing: float = 0.0,
        ignore_index: Optional[int] = None,
        dim: int = -1,
    ) -> None:
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.cls = classes
        self.dim = dim

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Calculates the mean label-smoothed cross-entropy.

        Args:
            pred: Unnormalized scores with classes along ``self.dim``.
            target: Integer class indices shaped like ``pred`` without the
                class dimension.

        Returns:
            A scalar tensor: the per-position losses averaged over every
            position. Ignored positions contribute zero but are still
            counted in the denominator.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        # The target distribution is a constant; keep it out of autograd.
        with torch.no_grad():
            if self.ignore_index is None:
                ignored = None
                safe_target = target
            else:
                ignored = target == self.ignore_index
                # Ignored targets may be out of range (e.g. -100), which
                # scatter_ rejects; point them at class 0 instead. Their
                # rows are zeroed below, so the choice of class is irrelevant.
                safe_target = target.masked_fill(ignored, 0)

            true_dist = torch.full_like(
                log_probs, self.smoothing / (self.cls - 1)
            )
            # Put the confidence mass on the true class along the class dim.
            true_dist.scatter_(
                self.dim, safe_target.unsqueeze(self.dim), self.confidence
            )

            if ignored is not None:
                # Only a real class index can have its smoothing mass removed.
                if 0 <= self.ignore_index < true_dist.size(self.dim):
                    true_dist.select(self.dim, self.ignore_index).fill_(0.0)
                # Zero the whole distribution at ignored positions so they
                # contribute nothing to the loss.
                true_dist.masked_fill_(ignored.unsqueeze(self.dim), 0.0)

        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))