summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-07-05 23:05:25 +0200
commit4d1f2cef39688871d2caafce42a09316381a27ae (patch)
tree0f4385969e7df6d7d313cd5910bde9a7475ca027 /text_recognizer/criterions
parentf0481decdad9afb52494e9e95996deef843ef233 (diff)
Refactor with attr, working on cnn+transformer network
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r--text_recognizer/criterions/__init__.py1
-rw-r--r--text_recognizer/criterions/label_smoothing_loss.py42
2 files changed, 43 insertions, 0 deletions
diff --git a/text_recognizer/criterions/__init__.py b/text_recognizer/criterions/__init__.py
new file mode 100644
index 0000000..5b0a7ab
--- /dev/null
+++ b/text_recognizer/criterions/__init__.py
@@ -0,0 +1 @@
+"""Module with custom loss functions."""
diff --git a/text_recognizer/criterions/label_smoothing_loss.py b/text_recognizer/criterions/label_smoothing_loss.py
new file mode 100644
index 0000000..40a7609
--- /dev/null
+++ b/text_recognizer/criterions/label_smoothing_loss.py
@@ -0,0 +1,42 @@
+"""Implementations of custom loss functions."""
+import torch
+from torch import nn
+from torch import Tensor
+import torch.nn.functional as F
+
+
class LabelSmoothingLoss(nn.Module):
    """Label smoothing cross entropy loss.

    Instead of a one-hot target distribution, the correct class receives
    probability ``1 - label_smoothing`` and the remaining mass is spread
    uniformly over the other classes. The loss is the sum-reduced KL
    divergence between the network's log-probabilities and this smoothed
    target distribution.
    """

    def __init__(
        self, label_smoothing: float, vocab_size: int, ignore_index: int = -100
    ) -> None:
        """Initializes the loss.

        Args:
            label_smoothing (float): Probability mass moved off the correct
                class; must lie in (0, 1].
            vocab_size (int): Number of classes.
            ignore_index (int): Target value whose positions contribute no
                loss. Defaults to -100.

        Raises:
            ValueError: If ``label_smoothing`` is outside (0, 1].
        """
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if not 0.0 < label_smoothing <= 1.0:
            raise ValueError(
                f"label_smoothing must be in (0, 1], got {label_smoothing}"
            )
        self.ignore_index = ignore_index

        # Spread the smoothing mass over every class except the correct one
        # and the ignored one (hence vocab_size - 2).
        smoothing_value = label_smoothing / (vocab_size - 2)
        one_hot = torch.full((vocab_size,), smoothing_value)
        # Zero the ignored class only when it is a real class index. The
        # default -100 lies outside the vocabulary; indexing with it would
        # raise an IndexError (vocab_size < 100) or silently zero entry
        # vocab_size - 100 (vocab_size >= 100).
        if 0 <= self.ignore_index < vocab_size:
            one_hot[self.ignore_index] = 0
        self.register_buffer("one_hot", one_hot.unsqueeze(0))

        self.confidence = 1.0 - label_smoothing

    def forward(self, output: Tensor, targets: Tensor) -> Tensor:
        """Computes the loss.

        Args:
            output (Tensor): Log-probabilities from the network
                (``F.kl_div`` expects its input in log space).
            targets (Tensor): Ground truth class indices; positions equal to
                ``ignore_index`` contribute no loss.

        Shapes:
            output: Batch size x num classes
            targets: Batch size

        Returns:
            Tensor: Label smoothing loss (sum-reduced KL divergence).
        """
        model_prob = self.one_hot.repeat(targets.size(0), 1)
        # Replace ignored targets with a valid index before scatter_ — a
        # negative ignore_index such as the default -100 would otherwise
        # crash scatter_. The corresponding rows are zeroed out afterwards,
        # so the substitute index never leaks into the loss.
        ignore_mask = targets == self.ignore_index
        safe_targets = targets.masked_fill(ignore_mask, 0)
        model_prob.scatter_(1, safe_targets.unsqueeze(1), self.confidence)
        model_prob.masked_fill_(ignore_mask.unsqueeze(1), 0)
        return F.kl_div(output, model_prob, reduction="sum")