From a2a3133ed5da283888efbdb9924d0e3733c274c8 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 9 May 2021 18:50:55 +0200 Subject: tranformer layer done --- text_recognizer/networks/loss/loss.py | 39 ----------------------------------- 1 file changed, 39 deletions(-) delete mode 100644 text_recognizer/networks/loss/loss.py (limited to 'text_recognizer/networks/loss/loss.py') diff --git a/text_recognizer/networks/loss/loss.py b/text_recognizer/networks/loss/loss.py deleted file mode 100644 index d12dc9c..0000000 --- a/text_recognizer/networks/loss/loss.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Implementations of custom loss functions.""" -import torch -from torch import nn -from torch import Tensor - -__all__ = ["LabelSmoothingCrossEntropy"] - - -class LabelSmoothingCrossEntropy(nn.Module): - """Label smoothing loss function.""" - - def __init__( - self, - classes: int, - smoothing: float = 0.0, - ignore_index: int = None, - dim: int = -1, - ) -> None: - super().__init__() - self.confidence = 1.0 - smoothing - self.smoothing = smoothing - self.ignore_index = ignore_index - self.cls = classes - self.dim = dim - - def forward(self, pred: Tensor, target: Tensor) -> Tensor: - """Calculates the loss.""" - pred = pred.log_softmax(dim=self.dim) - with torch.no_grad(): - # true_dist = pred.data.clone() - true_dist = torch.zeros_like(pred) - true_dist.fill_(self.smoothing / (self.cls - 1)) - true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence) - if self.ignore_index is not None: - true_dist[:, self.ignore_index] = 0 - mask = torch.nonzero(target == self.ignore_index, as_tuple=False) - if mask.dim() > 0: - true_dist.index_fill_(0, mask.squeeze(), 0.0) - return torch.mean(torch.sum(-true_dist * pred, dim=self.dim)) -- cgit v1.2.3-70-g09d2