summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-25 22:32:17 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-25 22:32:17 +0200
commitecf13ba9fc048ad81e8e21f0f9d68eb132605c39 (patch)
treec50652a2a7cda687344b77e9789df31f4fba66d0
parent41c3e99fe57874ba1855c893bf47087d474ec6b8 (diff)
Add ctc loss
-rw-r--r--text_recognizer/criterions/ctc.py38
1 files changed, 38 insertions, 0 deletions
diff --git a/text_recognizer/criterions/ctc.py b/text_recognizer/criterions/ctc.py
new file mode 100644
index 0000000..42a0b25
--- /dev/null
+++ b/text_recognizer/criterions/ctc.py
@@ -0,0 +1,38 @@
+"""CTC loss."""
+import torch
+from torch import LongTensor, nn, Tensor
+import torch.nn.functional as F
+
+
+class CTCLoss(nn.Module):
+ """CTC loss."""
+
+ def __init__(self, blank: int) -> None:
+ super().__init__()
+ self.blank = blank
+
+ def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
+ """Computes the CTC loss."""
+ device = outputs.device
+
+ log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
+ output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device)
+
+ targets_ = LongTensor([]).to(device)
+ target_lengths = LongTensor([]).to(device)
+ for target in targets:
+ # Remove padding
+ target = target[target != self.blank].to(device)
+ targets_ = torch.cat((targets_, target))
+ target_lengths = torch.cat(
+ (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0
+ )
+
+ return F.ctc_loss(
+ log_probs,
+ targets,
+ output_lengths,
+ target_lengths,
+ blank=self.blank,
+ zero_infinity=True,
+ )