diff options
Diffstat (limited to 'text_recognizer/networks/loss')
-rw-r--r-- | text_recognizer/networks/loss/label_smoothing_loss.py | 42 | ||||
-rw-r--r-- | text_recognizer/networks/loss/loss.py | 39 |
2 files changed, 42 insertions, 39 deletions
diff --git a/text_recognizer/networks/loss/label_smoothing_loss.py b/text_recognizer/networks/loss/label_smoothing_loss.py new file mode 100644 index 0000000..40a7609 --- /dev/null +++ b/text_recognizer/networks/loss/label_smoothing_loss.py @@ -0,0 +1,42 @@ +"""Implementations of custom loss functions.""" +import torch +from torch import nn +from torch import Tensor +import torch.nn.functional as F + + +class LabelSmoothingLoss(nn.Module): + """Label smoothing cross entropy loss.""" + + def __init__( + self, label_smoothing: float, vocab_size: int, ignore_index: int = -100 + ) -> None: + assert 0.0 < label_smoothing <= 1.0 + self.ignore_index = ignore_index + super().__init__() + + smoothing_value = label_smoothing / (vocab_size - 2) + one_hot = torch.full((vocab_size,), smoothing_value) + one_hot[self.ignore_index] = 0 + self.register_buffer("one_hot", one_hot.unsqueeze(0)) + + self.confidence = 1.0 - label_smoothing + + def forward(self, output: Tensor, targets: Tensor) -> Tensor: + """Computes the loss. + + Args: + output (Tensor): Predictions from the network. + targets (Tensor): Ground truth. + + Shapes: + outpus: Batch size x num classes + targets: Batch size + + Returns: + Tensor: Label smoothing loss. + """ + model_prob = self.one_hot.repeat(targets.size(0), 1) + model_prob.scatter_(1, targets.unsqueeze(1), self.confidence) + model_prob.masked_fill_((targets == self.ignore_index).unsqueeze(1), 0) + return F.kl_div(output, model_prob, reduction="sum") diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py deleted file mode 100644 index d12dc9c..0000000 --- a/text_recognizer/networks/loss/loss.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor - -__all__ = ["LabelSmoothingCrossEntropy"] - - -class LabelSmoothingCrossEntropy(nn.Module): - """Label smoothing loss function.""" - - def __init__( - self, - classes: int, - smoothing: float = 0.0, - ignore_index: int = None, - dim: int = -1, - ) -> None: - super().__init__() - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.ignore_index = ignore_index - self.cls = classes - self.dim = dim - - def forward(self, pred: Tensor, target: Tensor) -> Tensor: - """Calculates the loss.""" - pred = pred.log_softmax(dim=self.dim) - with torch.no_grad(): - # true_dist = pred.data.clone() - true_dist = torch.zeros_like(pred) - true_dist.fill_(self.smoothing / (self.cls - 1)) - true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) - if self.ignore_index is not None: - true_dist[:, self.ignore_index] = 0 - mask = torch.nonzero(target == self.ignore_index, as_tuple=False) - if mask.dim() > 0: - true_dist.index_fill_(0, mask.squeeze(), 0.0) - return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) |