diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2020-11-08 12:41:04 +0100 |
commit | beeaef529e7c893a3475fe27edc880e283373725 (patch) | |
tree | 59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer/networks/loss.py | |
parent | 4d7713746eb936832e84852e90292936b933e87d (diff) |
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer/networks/loss.py')
-rw-r--r-- | src/text_recognizer/networks/loss.py | 39 |
1 file changed, 37 insertions, 2 deletions
from typing import Optional


class LabelSmoothingCrossEntropy(nn.Module):
    """Cross entropy loss with label smoothing.

    Instead of a hard one-hot target, the true class receives probability
    ``1 - smoothing`` and the remaining ``smoothing`` mass is spread uniformly
    over the other ``classes - 1`` classes. With ``smoothing=0.0`` this reduces
    to standard cross entropy on log-softmaxed logits.

    Args:
        classes: Total number of classes (size of the prediction dimension).
        smoothing: Probability mass moved off the true class, in [0, 1).
        ignore_index: Optional class index whose targets contribute zero loss
            (e.g. a padding token). ``None`` disables ignoring.
        dim: Dimension of ``pred`` holding the class scores (default: last).
    """

    def __init__(
        self,
        classes: int,
        smoothing: float = 0.0,
        ignore_index: Optional[int] = None,
        dim: int = -1,
    ) -> None:
        super().__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.ignore_index = ignore_index
        self.cls = classes
        self.dim = dim

    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
        """Return the mean smoothed cross entropy of logits ``pred`` vs ``target``.

        Args:
            pred: Unnormalized logits, shape ``(batch, classes)`` with the class
                axis at ``self.dim``. NOTE(review): the ``ignore_index`` masking
                below indexes ``[:, idx]`` and fills along dim 0, so it assumes
                2-D input — confirm against callers for higher-rank use.
            target: Integer class indices, shape ``(batch,)``.

        Returns:
            Scalar mean loss over the batch.
        """
        log_probs = pred.log_softmax(dim=self.dim)
        with torch.no_grad():
            # Build the smoothed target distribution: uniform `smoothing`
            # mass everywhere, then `confidence` on the true class. Use
            # self.dim consistently with log_softmax (the original
            # hardcoded dim 1, which only matched for 2-D input).
            true_dist = torch.zeros_like(log_probs)
            true_dist.fill_(self.smoothing / (self.cls - 1))
            true_dist.scatter_(self.dim, target.unsqueeze(self.dim), self.confidence)
            if self.ignore_index is not None:
                # Zero the ignored class column so it never contributes mass.
                true_dist[:, self.ignore_index] = 0
                # Zero the rows of ignored targets. nonzero(...) is always
                # 2-D of shape (k, 1), so test numel, not dim (the original
                # `mask.dim() > 0` was always true); squeeze(1) keeps the
                # index 1-D even when k == 1 (plain squeeze() would yield a
                # 0-d tensor and break index_fill_).
                ignored = torch.nonzero(target == self.ignore_index, as_tuple=False)
                if ignored.numel() > 0:
                    true_dist.index_fill_(0, ignored.squeeze(1), 0.0)
        return torch.mean(torch.sum(-true_dist * log_probs, dim=self.dim))