diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-19 21:02:18 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-09-19 21:02:18 +0200 |
commit | d8899b2c97046e49807ba2a440ee2dda6e2db335 (patch) | |
tree | 4f2d65ac96fb5778e62dbc1bdcaa7d07c721b831 /text_recognizer/criterions/label_smoothing.py | |
parent | f9ede1e61008ead9b7abe910dff79067cf862312 (diff) |
Add more detailed description of label smoothing
Diffstat (limited to 'text_recognizer/criterions/label_smoothing.py')
-rw-r--r-- | text_recognizer/criterions/label_smoothing.py | 21 |
1 files changed, 17 insertions, 4 deletions
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py index cc71c45..74ff145 100644 --- a/text_recognizer/criterions/label_smoothing.py +++ b/text_recognizer/criterions/label_smoothing.py @@ -2,11 +2,24 @@ import torch from torch import nn from torch import Tensor -import torch.nn.functional as F class LabelSmoothingLoss(nn.Module): - def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1): + r"""Loss functions for making networks less over confident. + + It is used to calibrate the network so that the predicted probabilities + reflect the accuracy. The function is given by: + + L = (1 - \alpha) * y_hot + \alpha / K + + This means that some of the probability mass is transferred to the incorrect + labels, thus not forcing the network try to put all probability mass into + one label, and this works as a regulizer. + """ + + def __init__( + self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 + ) -> None: super().__init__() assert 0.0 < smoothing <= 1.0 self.ignore_index = ignore_index @@ -19,13 +32,13 @@ class LabelSmoothingLoss(nn.Module): Args: output (Tensor): outputictions from the network. - targets (Tensor): Ground truth. + target (Tensor): Ground truth. Shapes: TBC Returns: - Tensor: Label smoothing loss. + (Tensor): Label smoothing loss. """ output = output.log_softmax(dim=self.dim) with torch.no_grad(): |