summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions/ctc.py
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:41:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-10-27 22:41:39 +0200
commitbec4aafe707be8e5763ad6b2194d4589f20594a9 (patch)
tree506517ca6a17241a305114e787d1b899a48a3d86 /text_recognizer/criterions/ctc.py
parent9a8044f4a3826a119416665741b709cd686fca87 (diff)
Rename to criterion
Diffstat (limited to 'text_recognizer/criterions/ctc.py')
-rw-r--r--text_recognizer/criterions/ctc.py38
1 files changed, 0 insertions, 38 deletions
diff --git a/text_recognizer/criterions/ctc.py b/text_recognizer/criterions/ctc.py
deleted file mode 100644
index 42a0b25..0000000
--- a/text_recognizer/criterions/ctc.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""CTC loss."""
-import torch
-from torch import LongTensor, nn, Tensor
-import torch.nn.functional as F
-
-
-class CTCLoss(nn.Module):
- """CTC loss."""
-
- def __init__(self, blank: int) -> None:
- super().__init__()
- self.blank = blank
-
- def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
- """Computes the CTC loss."""
- device = outputs.device
-
- log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
- output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device)
-
- targets_ = LongTensor([]).to(device)
- target_lengths = LongTensor([]).to(device)
- for target in targets:
- # Remove padding
- target = target[target != self.blank].to(device)
- targets_ = torch.cat((targets_, target))
- target_lengths = torch.cat(
- (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0
- )
-
- return F.ctc_loss(
- log_probs,
- targets,
- output_lengths,
- target_lengths,
- blank=self.blank,
- zero_infinity=True,
- )