summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/loss.py
blob: cf9fa0d9309907ad139b9728c7488ecd7f2d5252 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""Implementations of custom loss functions."""
from typing import Optional

from pytorch_metric_learning import distances, losses, miners, reducers
import torch
from torch import nn
from torch import Tensor
from torch.autograd import Variable
import torch.nn.functional as F

# Public names of this module, i.e. what `from <module> import *` exposes.
__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]


class EmbeddingLoss:
    """Triplet-margin metric loss over latent embeddings, with tuple mining.

    The loss combines four `pytorch_metric_learning` components that are
    stored as attributes: a cosine-similarity distance, a threshold reducer
    (constructed with `low=0`), a triplet margin loss and a multi-similarity
    miner. The same distance object is shared by the loss and the miner.
    """

    def __init__(self, margin: float = 0.2, type_of_triplets: str = "semihard") -> None:
        """Builds the distance, reducer, loss function and miner.

        Args:
            margin (float): Used both as the margin of the triplet loss and as
                the `epsilon` of the miner.
            type_of_triplets (str): Accepted for interface compatibility; the
                constructor does not currently pass it to any component.

        """
        cosine = distances.CosineSimilarity()
        threshold = reducers.ThresholdReducer(low=0)

        self.distance = cosine
        self.reducer = threshold
        self.loss_fn = losses.TripletMarginLoss(
            margin=margin, distance=cosine, reducer=threshold
        )
        self.miner = miners.MultiSimilarityMiner(epsilon=margin, distance=cosine)

    def __call__(self, embeddings: Tensor, labels: Tensor) -> Tensor:
        """Mines tuples from the batch and evaluates the metric loss on them.

        Args:
            embeddings (Tensor): The latent vectors encoded by the network.
            labels (Tensor): Labels of the embeddings.

        Returns:
            Tensor: The metric loss for the embeddings.

        """
        # The miner output is forwarded unchanged as the third argument of the
        # loss function.
        mined_tuples = self.miner(embeddings, labels)
        return self.loss_fn(embeddings, labels, mined_tuples)


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross-entropy against label-smoothed target distributions.

    For every prediction the target distribution puts `1 - smoothing` on the
    target class and `smoothing / (classes - 1)` on each other class. The loss
    is the mean, over all positions, of the cross-entropy between that
    distribution and the log-softmax of the prediction.

    If `ignore_index` is given, that class never receives any probability mass
    and positions whose target equals `ignore_index` contribute zero loss.
    Those positions are still counted in the mean.
    """

    def __init__(
        self,
        classes: int,
        smoothing: float = 0.0,
        ignore_index: Optional[int] = None,
        dim: int = -1,
    ) -> None:
        """Stores the smoothing configuration.

        Args:
            classes (int): Number of classes; the size of `pred` along `dim`.
            smoothing (float): Probability mass spread over non-target classes.
            ignore_index (Optional[int]): Class index to mask out, or None to
                disable masking.
            dim (int): Axis of `pred` that holds the class scores.

        """
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.cls = classes
        self.dim = dim

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Calculates the loss.

        Args:
            pred (Tensor): Unnormalised class scores with the class axis at
                `self.dim`.
            target (Tensor): Integer class indices; same shape as `pred` with
                the class axis removed.

        Returns:
            Tensor: Scalar loss averaged over all positions.

        """
        log_probs = pred.log_softmax(dim=self.dim)

        # The smoothed targets are constants; keep them out of the autograd graph.
        with torch.no_grad():
            true_dist = torch.full_like(log_probs, self.smoothing / (self.cls - 1))

            # Use `self.dim` as the class axis everywhere so that softmax,
            # scatter and the final sum always agree, whatever the input rank.
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)

            if self.ignore_index is not None:
                # The ignored class never gets probability mass ...
                true_dist.select(self.dim, self.ignore_index).zero_()
                # ... and positions labelled with it get an all-zero target.
                # The mask broadcasts along the class axis, so this works for
                # any target shape and needs no index extraction.
                ignored = (target == self.ignore_index).unsqueeze(self.dim)
                true_dist.masked_fill_(ignored, 0.0)

        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))