path: root/src/text_recognizer/networks/loss.py
author     aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 12:41:04 +0100
committer  aktersnurra <gustaf.rydholm@gmail.com>  2020-11-08 12:41:04 +0100
commit     beeaef529e7c893a3475fe27edc880e283373725 (patch)
tree       59eb72562bf7a5a9470c2586e6280600ad94f1ae /src/text_recognizer/networks/loss.py
parent     4d7713746eb936832e84852e90292936b933e87d (diff)
Trying to get the CNNTransformer to work, but it is hard.
Diffstat (limited to 'src/text_recognizer/networks/loss.py')
-rw-r--r--  src/text_recognizer/networks/loss.py  |  39
1 file changed, 37 insertions(+), 2 deletions(-)
diff --git a/src/text_recognizer/networks/loss.py b/src/text_recognizer/networks/loss.py
index ff843cf..cf9fa0d 100644
--- a/src/text_recognizer/networks/loss.py
+++ b/src/text_recognizer/networks/loss.py
@@ -1,10 +1,12 @@
 """Implementations of custom loss functions."""
 from pytorch_metric_learning import distances, losses, miners, reducers
+import torch
 from torch import nn
 from torch import Tensor
+from torch.autograd import Variable
+import torch.nn.functional as F
 
-
-__all__ = ["EmbeddingLoss"]
+__all__ = ["EmbeddingLoss", "LabelSmoothingCrossEntropy"]
 
 
 class EmbeddingLoss:
@@ -32,3 +34,36 @@ class EmbeddingLoss:
         hard_pairs = self.miner(embeddings, labels)
         loss = self.loss_fn(embeddings, labels, hard_pairs)
         return loss
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+    """Label smoothing loss function."""
+
+    def __init__(
+        self,
+        classes: int,
+        smoothing: float = 0.0,
+        ignore_index: int = None,
+        dim: int = -1,
+    ) -> None:
+        super().__init__()
+        self.confidence = 1.0 - smoothing
+        self.smoothing = smoothing
+        self.ignore_index = ignore_index
+        self.cls = classes
+        self.dim = dim
+
+    def forward(self, pred: Tensor, target: Tensor) -> Tensor:
+        """Calculates the loss."""
+        pred = pred.log_softmax(dim=self.dim)
+        with torch.no_grad():
+            # Build the smoothed target distribution: every class gets
+            # smoothing / (classes - 1); the true class gets confidence.
+            true_dist = torch.zeros_like(pred)
+            true_dist.fill_(self.smoothing / (self.cls - 1))
+            true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
+            if self.ignore_index is not None:
+                # Zero the ignored class column and every row whose target
+                # is the ignore index, so they contribute nothing to the loss.
+                true_dist[:, self.ignore_index] = 0
+                mask = torch.nonzero(target == self.ignore_index, as_tuple=False)
+                if mask.numel() > 0:
+                    true_dist.index_fill_(0, mask.squeeze(1), 0.0)
+        return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
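Below is a minimal, hypothetical usage sketch of the LabelSmoothingCrossEntropy class added in this commit; the module path, batch size, class count, and smoothing value are illustrative assumptions, not part of the commit itself.

# Hypothetical usage sketch -- values below are illustrative assumptions.
import torch

from text_recognizer.networks.loss import LabelSmoothingCrossEntropy

num_classes = 10
loss_fn = LabelSmoothingCrossEntropy(classes=num_classes, smoothing=0.1)

logits = torch.randn(4, num_classes, requires_grad=True)  # (batch, classes)
targets = torch.randint(0, num_classes, (4,))  # integer class indices

loss = loss_fn(logits, targets)  # scalar smoothed cross entropy
loss.backward()  # gradients flow through the log-softmax of the logits

# With smoothing=0.1 the target distribution places 0.9 on the true class
# and 0.1 / (num_classes - 1) on each of the remaining classes.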