diff options
Diffstat (limited to 'text_recognizer/criterions')
| -rw-r--r-- | text_recognizer/criterions/label_smoothing.py | 21 | 
1 files changed, 17 insertions, 4 deletions
| diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py index cc71c45..74ff145 100644 --- a/text_recognizer/criterions/label_smoothing.py +++ b/text_recognizer/criterions/label_smoothing.py @@ -2,11 +2,24 @@  import torch  from torch import nn  from torch import Tensor -import torch.nn.functional as F  class LabelSmoothingLoss(nn.Module): -    def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1): +    r"""Loss functions for making networks less over confident. + +    It is used to calibrate the network so that the predicted probabilities +    reflect the accuracy. The function is given by: + +        L = (1 - \alpha) * y_hot + \alpha / K + +    This means that some of the probability mass is transferred to the incorrect +    labels, thus not forcing the network try to put all probability mass into +    one label, and this works as a regulizer. +    """ + +    def __init__( +        self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 +    ) -> None:          super().__init__()          assert 0.0 < smoothing <= 1.0          self.ignore_index = ignore_index @@ -19,13 +32,13 @@ class LabelSmoothingLoss(nn.Module):          Args:              output (Tensor): outputictions from the network. -            targets (Tensor): Ground truth. +            target (Tensor): Ground truth.          Shapes:              TBC          Returns: -            Tensor: Label smoothing loss. +            (Tensor): Label smoothing loss.          """          output = output.log_softmax(dim=self.dim)          with torch.no_grad(): |