diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-25 22:32:17 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-25 22:32:17 +0200 |
commit | ecf13ba9fc048ad81e8e21f0f9d68eb132605c39 (patch) | |
tree | c50652a2a7cda687344b77e9789df31f4fba66d0 /text_recognizer/criterions | |
parent | 41c3e99fe57874ba1855c893bf47087d474ec6b8 (diff) |
Add ctc loss
Diffstat (limited to 'text_recognizer/criterions')
-rw-r--r-- | text_recognizer/criterions/ctc.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/text_recognizer/criterions/ctc.py b/text_recognizer/criterions/ctc.py new file mode 100644 index 0000000..42a0b25 --- /dev/null +++ b/text_recognizer/criterions/ctc.py @@ -0,0 +1,38 @@ +"""CTC loss.""" +import torch +from torch import LongTensor, nn, Tensor +import torch.nn.functional as F + + +class CTCLoss(nn.Module): + """CTC loss.""" + + def __init__(self, blank: int) -> None: + super().__init__() + self.blank = blank + + def forward(self, outputs: Tensor, targets: Tensor) -> Tensor: + """Computes the CTC loss.""" + device = outputs.device + + log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2) + output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device) + + targets_ = LongTensor([]).to(device) + target_lengths = LongTensor([]).to(device) + for target in targets: + # Remove padding + target = target[target != self.blank].to(device) + targets_ = torch.cat((targets_, target)) + target_lengths = torch.cat( + (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0 + ) + + return F.ctc_loss( + log_probs, + targets, + output_lengths, + target_lengths, + blank=self.blank, + zero_infinity=True, + ) |