summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r--text_recognizer/criterions/label_smoothing.py21
1 files changed, 17 insertions, 4 deletions
diff --git a/text_recognizer/criterions/label_smoothing.py b/text_recognizer/criterions/label_smoothing.py
index cc71c45..74ff145 100644
--- a/text_recognizer/criterions/label_smoothing.py
+++ b/text_recognizer/criterions/label_smoothing.py
@@ -2,11 +2,24 @@
import torch
from torch import nn
from torch import Tensor
-import torch.nn.functional as F
class LabelSmoothingLoss(nn.Module):
- def __init__(self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1):
+ r"""Loss functions for making networks less over confident.
+
+ It is used to calibrate the network so that the predicted probabilities
+ reflect the accuracy. The function is given by:
+
+ L = (1 - \alpha) * y_hot + \alpha / K
+
+ This means that some of the probability mass is transferred to the incorrect
+ labels, thus not forcing the network try to put all probability mass into
+ one label, and this works as a regulizer.
+ """
+
+ def __init__(
+ self, ignore_index: int = -100, smoothing: float = 0.0, dim: int = -1
+ ) -> None:
super().__init__()
assert 0.0 < smoothing <= 1.0
self.ignore_index = ignore_index
@@ -19,13 +32,13 @@ class LabelSmoothingLoss(nn.Module):
Args:
output (Tensor): outputictions from the network.
- targets (Tensor): Ground truth.
+ target (Tensor): Ground truth.
Shapes:
TBC
Returns:
- Tensor: Label smoothing loss.
+ (Tensor): Label smoothing loss.
"""
output = output.log_softmax(dim=self.dim)
with torch.no_grad():