summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-11-21 23:19:59 +0100
committer Gustaf Rydholm <gustaf.rydholm@gmail.com> 2021-11-21 23:19:59 +0100
commit ca8ac871e124120ade3669bdaa69a5acc746ed40 (patch)
tree 84a892c40d9f76fc7c322b73d7c05950a781b5da
parent b02768eea1c39601c5fb90bf125d8dadbc07fb1c (diff)
Remove criterion and ctc loss
-rw-r--r-- text_recognizer/criterion/__init__.py 1
-rw-r--r-- text_recognizer/criterion/ctc.py 38
2 files changed, 0 insertions, 39 deletions
diff --git a/text_recognizer/criterion/__init__.py b/text_recognizer/criterion/__init__.py
deleted file mode 100644
index 5b0a7ab..0000000
--- a/text_recognizer/criterion/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-"""Module with custom loss functions."""
diff --git a/text_recognizer/criterion/ctc.py b/text_recognizer/criterion/ctc.py
deleted file mode 100644
index 42a0b25..0000000
--- a/text_recognizer/criterion/ctc.py
+++ /dev/null
@@ -1,38 +0,0 @@
-"""CTC loss."""
-import torch
-from torch import LongTensor, nn, Tensor
-import torch.nn.functional as F
-
-
-class CTCLoss(nn.Module):
- """CTC loss."""
-
- def __init__(self, blank: int) -> None:
- super().__init__()
- self.blank = blank
-
- def forward(self, outputs: Tensor, targets: Tensor) -> Tensor:
- """Computes the CTC loss."""
- device = outputs.device
-
- log_probs = F.log_softmax(outputs, dim=2).permute(1, 0, 2)
- output_lengths = LongTensor([outputs.shape[1]] * outputs.shape[0]).to(device)
-
- targets_ = LongTensor([]).to(device)
- target_lengths = LongTensor([]).to(device)
- for target in targets:
- # Remove padding
- target = target[target != self.blank].to(device)
- targets_ = torch.cat((targets_, target))
- target_lengths = torch.cat(
- (target_lengths, torch.LongTensor([len(target)]).to(device)), dim=0
- )
-
- return F.ctc_loss(
- log_probs,
- targets,
- output_lengths,
- target_lengths,
- blank=self.blank,
- zero_infinity=True,
- )