diff options
-rw-r--r-- | text_recognizer/models/conformer.py | 27 |
1 file changed, 13 insertions, 14 deletions
diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py index ee3d1e3..487eabe 100644 --- a/text_recognizer/models/conformer.py +++ b/text_recognizer/models/conformer.py @@ -51,19 +51,18 @@ class LitConformer(LitBase): """Predicts a sequence of characters.""" logits = self(x) logprobs = torch.log_softmax(logits, dim=1) - pred = self.decode(logprobs, self.max_output_len)[0] - return "".join([self.mapping[i] for i in pred if i not in self.ignore_indices]) + return self.decode(logprobs, self.max_output_len) def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, targets = batch logits = self(data) logprobs = torch.log_softmax(logits, dim=1) - B, _, S = logprobs.shape - input_length = torch.ones(B).types_as(logprobs).int() * S - target_length = first_element(targets, self.pad_index).types_as(targets) + B, S, _ = logprobs.shape + input_length = torch.ones(B).type_as(logprobs).int() * S + target_length = first_element(targets, self.pad_index).type_as(targets) loss = self.loss_fn( - logprobs.permute(2, 0, 1), targets, input_length, target_length + logprobs.permute(1, 0, 2), targets, input_length, target_length ) self.log("train/loss", loss) return loss @@ -73,11 +72,11 @@ class LitConformer(LitBase): data, targets = batch logits = self(data) logprobs = torch.log_softmax(logits, dim=1) - B, _, S = logprobs.shape - input_length = torch.ones(B).types_as(logprobs).int() * S - target_length = first_element(targets, self.pad_index).types_as(targets) + B, S, _ = logprobs.shape + input_length = torch.ones(B).type_as(logprobs).int() * S + target_length = first_element(targets, self.pad_index).type_as(targets) loss = self.loss_fn( - logprobs.permute(2, 0, 1), targets, input_length, target_length + logprobs.permute(1, 0, 2), targets, input_length, target_length ) self.log("val/loss", loss) preds = self.decode(logprobs, targets.shape[1]) @@ -105,15 +104,15 @@ class LitConformer(LitBase): max_length 
(int): Max length of a sequence. Shapes: - - x: :math: `(B, C, Y)` - - output: :math: `(B, S)` + - x: :math: `(B, T, C)` + - output: :math: `(B, T)` Returns: Tensor: A predicted sequence of characters. """ B = logprobs.shape[0] - argmax = logprobs.argmax(1) - decoded = torch.ones((B, max_length)).types_as(logprobs).int() * self.pad_index + argmax = logprobs.argmax(2) + decoded = torch.ones((B, max_length)).type_as(logprobs).int() * self.pad_index for i in range(B): seq = [ b |