summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- text_recognizer/models/conformer.py 27
1 files changed, 13 insertions, 14 deletions
diff --git a/text_recognizer/models/conformer.py b/text_recognizer/models/conformer.py
index ee3d1e3..487eabe 100644
--- a/text_recognizer/models/conformer.py
+++ b/text_recognizer/models/conformer.py
@@ -51,19 +51,18 @@ class LitConformer(LitBase):
"""Predicts a sequence of characters."""
logits = self(x)
logprobs = torch.log_softmax(logits, dim=1)
- pred = self.decode(logprobs, self.max_output_len)[0]
- return "".join([self.mapping[i] for i in pred if i not in self.ignore_indices])
+ return self.decode(logprobs, self.max_output_len)
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, targets = batch
logits = self(data)
logprobs = torch.log_softmax(logits, dim=1)
- B, _, S = logprobs.shape
- input_length = torch.ones(B).types_as(logprobs).int() * S
- target_length = first_element(targets, self.pad_index).types_as(targets)
+ B, S, _ = logprobs.shape
+ input_length = torch.ones(B).type_as(logprobs).int() * S
+ target_length = first_element(targets, self.pad_index).type_as(targets)
loss = self.loss_fn(
- logprobs.permute(2, 0, 1), targets, input_length, target_length
+ logprobs.permute(1, 0, 2), targets, input_length, target_length
)
self.log("train/loss", loss)
return loss
@@ -73,11 +72,11 @@ class LitConformer(LitBase):
data, targets = batch
logits = self(data)
logprobs = torch.log_softmax(logits, dim=1)
- B, _, S = logprobs.shape
- input_length = torch.ones(B).types_as(logprobs).int() * S
- target_length = first_element(targets, self.pad_index).types_as(targets)
+ B, S, _ = logprobs.shape
+ input_length = torch.ones(B).type_as(logprobs).int() * S
+ target_length = first_element(targets, self.pad_index).type_as(targets)
loss = self.loss_fn(
- logprobs.permute(2, 0, 1), targets, input_length, target_length
+ logprobs.permute(1, 0, 2), targets, input_length, target_length
)
self.log("val/loss", loss)
preds = self.decode(logprobs, targets.shape[1])
@@ -105,15 +104,15 @@ class LitConformer(LitBase):
max_length (int): Max length of a sequence.
Shapes:
- - x: :math: `(B, C, Y)`
- - output: :math: `(B, S)`
+ - x: :math: `(B, T, C)`
+ - output: :math: `(B, T)`
Returns:
Tensor: A predicted sequence of characters.
"""
B = logprobs.shape[0]
- argmax = logprobs.argmax(1)
- decoded = torch.ones((B, max_length)).types_as(logprobs).int() * self.pad_index
+ argmax = logprobs.argmax(2)
+ decoded = torch.ones((B, max_length)).type_as(logprobs).int() * self.pad_index
for i in range(B):
seq = [
b