From ca8ac871e124120ade3669bdaa69a5acc746ed40 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sun, 21 Nov 2021 23:19:59 +0100 Subject: Remove criterion and ctc loss --- text_recognizer/criterion/__init__.py | 1 - text_recognizer/criterion/ctc.py | 38 ----------------------------------- 2 files changed, 39 deletions(-) delete mode 100644 text_recognizer/criterion/__init__.py delete mode 100644 text_recognizer/criterion/ctc.py (limited to 'text_recognizer/criterion') diff --git a/text_recognizer/criterion/__init__.py b/text_recognizer/criterion/__init__.py deleted file mode 100644 index 5b0a7ab..0000000 --- a/text_recognizer/criterion/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Module with custom loss functions.""" diff --git a/text_recognizer/criterion/ctc.py b/text_recognizer/criterion/ctc.py deleted file mode 100644 index 42a0b25..0000000 --- a/text_recognizer/criterion/ctc.py +++ /dev/null @@ -1,38 +0,0 @@ -"""CTC loss.""" -import torch -from torch import LongTensor, nn, Tensor -import torch.nn.functional as F - - -class CTCLoss(nn.Module): - """CTC loss.""" - - def __init__(self, blank: int) -> None: - super().__init__() - self.blank = blank - - def forward(self, outputs: Tensor, targets: Tensor) -> Tensor: - """Computes the CTC loss.""" - device = outputs.device - - log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2) - output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device) - - targets_ = LongTensor([]).to(device) - target_lengths = LongTensor([]).to(device) - for target in targets: - # Remove padding - target = target[target != self.blank].to(device) - targets_ = torch.cat((targets_, target)) - target_lengths = torch.cat( - (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0 - ) - - return F.ctc_loss( - log_probs, - targets, - output_lengths, - target_lengths, - blank=self.blank, - zero_infinity=True, - ) -- cgit v1.2.3-70-g09d2