summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r--text_recognizer/criterions/ctc.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/text_recognizer/criterions/ctc.py b/text_recognizer/criterions/ctc.py
new file mode 100644
index 0000000..42a0b25
--- /dev/null
+++ b/text_recognizer/criterions/ctc.py
@@ -0,0 +1,38 @@
+"""CTC loss."""
+import torch
+from torch import LongTensor, nn, Tensor
+import torch.nn.functional as F
+
+
+class CTCLoss(nn.Module):
+ """CTC loss."""
+
+ def __init__(self, blank: int) -> None:
+ super().__init__()
+ self.blank = blank
+
+ def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss."""
+ device = outputs.device
+
+ log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
+ output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device)
+
+ targets_ = LongTensor([]).to(device)
+ target_lengths = LongTensor([]).to(device)
+ for target in targets:
+ # Remove padding
+ target = target[target != self.blank].to(device)
+ targets_ = torch.cat((targets_, target))
+ target_lengths = torch.cat(
+ (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0
+ )
+
+ return F.ctc_loss(
+ log_probs,
+ targets,
+ output_lengths,
+ target_lengths,
+ blank=self.blank,
+ zero_infinity=True,
+ )