summaryrefslogtreecommitdiff
path: root/text_recognizer/criterions/ctc.py
blob: 42a0b2506f4a159b5aa49b4181ce1db55cb25c75 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""CTC loss."""
import torch
from torch import LongTensor, nn, Tensor
import torch.nn.functional as F


class CTCLoss(nn.Module):
    """Connectionist Temporal Classification (CTC) loss over padded targets.

    The ``blank`` index doubles as the padding value in ``targets``: every
    occurrence of it is stripped before the labels are handed to
    ``F.ctc_loss``.
    """

    def __init__(self, blank: int) -> None:
        """Store the blank/padding class index.

        Args:
            blank: Index of the CTC blank class; also treated as padding in
                the targets passed to ``forward``.
        """
        super().__init__()
        self.blank = blank

    def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
        """Computes the CTC loss.

        Args:
            outputs: Unnormalized class scores laid out as
                ``(batch, time, num_classes)``; log-softmax is applied here.
            targets: Label indices, assumed ``(batch, max_target_len)`` and
                padded with ``self.blank``. Every blank entry is removed,
                wherever it appears in a row.

        Returns:
            The CTC loss as returned by ``F.ctc_loss`` with its default
            reduction; infinite losses are zeroed (``zero_infinity=True``).
        """
        device = outputs.device
        batch_size, seq_len = outputs.shape[0], outputs.shape[1]

        # F.ctc_loss expects log-probabilities laid out as (time, batch, classes).
        log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
        # Every sequence in the batch uses the full time dimension.
        output_lengths = torch.full(
            (batch_size,), seq_len, dtype=torch.long, device=device
        )

        # Strip padding. Boolean-mask indexing returns the kept labels in
        # row-major order, i.e. the per-sample concatenation F.ctc_loss expects
        # for 1-D targets; the per-row count of kept labels is each length.
        targets = targets.to(device)
        keep = targets != self.blank
        flat_targets = targets[keep].long()
        target_lengths = keep.sum(dim=1).long()

        return F.ctc_loss(
            log_probs,
            flat_targets,
            output_lengths,
            target_lengths,
            blank=self.blank,
            zero_infinity=True,
        )