From d8899b2c97046e49807ba2a440ee2dda6e2db335 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 19 Sep 2021 21:02:18 +0200 Subject: Add more detailed description of label smoothing --- text_recognizer/criterions/label_smoothing.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py index cc71c45..74ff145 100644 --- a/text_recognizer/criterions/label_smoothing.py +++ b/text_recognizer/criterions/label_smoothing.py @@ -2,11 +2,24 @@ import torch from torch import nn from torch import Tensor -import torch.nn.functional as F class LabelSmoothingLoss(nn.Module): - def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1): + r"""Loss function for making networks less overconfident. + + It is used to calibrate the network so that the predicted probabilities + reflect the accuracy. The function is given by: + + L = (1 - \alpha) * y_hot + \alpha / K + + This means that some of the probability mass is transferred to the incorrect + labels, so the network is not forced to put all probability mass into + one label, and this works as a regularizer. + """ + + def __init__( + self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1 + ) -> None: super().__init__() assert 0.0 < smoothing <= 1.0 self.ignore_index = ignore_index @@ -19,13 +32,13 @@ class LabelSmoothingLoss(nn.Module): Args: output (Tensor): outputictions from the network. - targets (Tensor): Ground truth. + target (Tensor): Ground truth. Shapes: TBC Returns: - Tensor: Label smoothing loss. + (Tensor): Label smoothing loss. """ output = output.log_softmax(dim=self.dim) with torch.no_grad(): -- cgit v1.2.3-70-g09d2