diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:41:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-10-27 22:41:39 +0200 |
commit | bec4aafe707be8e5763ad6b2194d4589f20594a9 (patch) | |
tree | 506517ca6a17241a305114e787d1b899a48a3d86 /text_recognizer/criterions/ctc.py | |
parent | 9a8044f4a3826a119416665741b709cd686fca87 (diff) |
Rename to criterion
Diffstat (limited to 'text_recognizer/criterions/ctc.py')
-rw-r--r-- | text_recognizer/criterions/ctc.py | 38 |
1 files changed, 0 insertions, 38 deletions
diff --git a/text_recognizer/criterions/ctc.py b/text_recognizer/criterions/ctc.py
deleted file mode 100644
index 42a0b25..0000000
--- a/text_recognizer/criterions/ctc.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""CTC loss."""
-import torch
-from torch import LongTensor, nn, Tensor
-import torch.nn.functional as F
-
-
-class CTCLoss(nn.Module):
-    """CTC loss."""
-
-    def __init__(self, blank: int) -> None:
-        super().__init__()
-        self.blank = blank
-
-    def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
-        """Computes the CTC loss."""
-        device = outputs.device
-
-        log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
-        output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device)
-
-        targets_ = LongTensor([]).to(device)
-        target_lengths = LongTensor([]).to(device)
-        for target in targets:
-            # Remove padding
-            target = target[target != self.blank].to(device)
-            targets_ = torch.cat((targets_, target))
-            target_lengths = torch.cat(
-                (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0
-            )
-
-        return F.ctc_loss(
-            log_probs,
-            targets,
-            output_lengths,
-            target_lengths,
-            blank=self.blank,
-            zero_infinity=True,
-        )
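A side note on the removed module: it builds an unpadded `targets_` tensor and matching `target_lengths`, but then passes the original padded `targets` to `F.ctc_loss`, which looks like a bug. The sketch below is not part of the repository; it is a hedged, standalone variant assuming `outputs` has shape `(batch, time, classes)` and `targets` is padded with the `blank` index, showing how the unpadded targets could be fed to `F.ctc_loss` directly.

```python
# Hypothetical rewrite, not the repository's code: strip padding with a mask
# and pass the flattened (unpadded) targets to F.ctc_loss.
import torch
from torch import Tensor, nn
import torch.nn.functional as F


class CTCLoss(nn.Module):
    """CTC loss over padded target sequences."""

    def __init__(self, blank: int) -> None:
        super().__init__()
        self.blank = blank

    def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
        """Computes the CTC loss.

        Args:
            outputs: Network output of shape (batch, time, classes).
            targets: Targets of shape (batch, max_target_len), padded with `blank`.
        """
        device = outputs.device
        batch_size, seq_len, _ = outputs.shape

        # CTC expects log-probabilities with the time dimension first: (T, N, C).
        log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
        output_lengths = torch.full(
            (batch_size,), seq_len, dtype=torch.long, device=device
        )

        # Drop the padding and flatten the targets into a single 1-D tensor,
        # which F.ctc_loss accepts together with per-sample target lengths.
        mask = targets != self.blank
        target_lengths = mask.sum(dim=1).to(torch.long)
        flat_targets = targets[mask].to(device)

        return F.ctc_loss(
            log_probs,
            flat_targets,  # unpadded targets, not the padded tensor
            output_lengths,
            target_lengths,
            blank=self.blank,
            zero_infinity=True,
        )
```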