summaryrefslogtreecommitdiff
path: root/text_recognizer/models/transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/models/transformer.py')
-rw-r--r--text_recognizer/models/transformer.py65
1 files changed, 12 insertions, 53 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6048901..bbfaac0 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,6 +1,6 @@
"""Lightning model for base Transformers."""
-from collections.abc import Sequence
-from typing import Optional, Tuple, Type
+from typing import Callable, Optional, Sequence, Tuple, Type
+from text_recognizer.models.greedy_decoder import GreedyDecoder
import torch
from omegaconf import DictConfig
@@ -20,6 +20,7 @@ class LitTransformer(LitBase):
loss_fn: Type[nn.Module],
optimizer_config: DictConfig,
tokenizer: Tokenizer,
+ decoder: Callable = GreedyDecoder,
lr_scheduler_config: Optional[DictConfig] = None,
max_output_len: int = 682,
) -> None:
@@ -35,9 +36,10 @@ class LitTransformer(LitBase):
self.test_cer = CharErrorRate()
self.val_wer = WordErrorRate()
self.test_wer = WordErrorRate()
+ self.decoder = decoder
def forward(self, data: Tensor) -> Tensor:
- """Forward pass with the transformer network."""
+ """Autoregressive forward pass."""
return self.predict(data)
def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
@@ -59,7 +61,7 @@ class LitTransformer(LitBase):
logits = self.teacher_forward(data, targets[:, :-1])
loss = self.loss_fn(logits, targets[:, 1:])
preds = self.predict(data)
- pred_text, target_text = self._get_text(preds), self._get_text(targets)
+ pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
self.val_acc(preds, targets)
self.val_cer(pred_text, target_text)
@@ -76,7 +78,7 @@ class LitTransformer(LitBase):
logits = self.teacher_forward(data, targets[:, :-1])
loss = self.loss_fn(logits, targets[:, 1:])
preds = self(data)
- pred_text, target_text = self._get_text(preds), self._get_text(targets)
+ pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
self.test_acc(preds, targets)
self.test_cer(pred_text, target_text)
@@ -86,55 +88,12 @@ class LitTransformer(LitBase):
self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
- def _get_text(
+ def _to_tokens(
self,
- xs: Tensor,
- ) -> Tuple[Sequence[str], Sequence[str]]:
- return [self.tokenizer.decode(x) for x in xs]
+ indecies: Tensor,
+ ) -> Sequence[str]:
+ return [self.tokenizer.decode(i) for i in indecies]
@torch.no_grad()
def predict(self, x: Tensor) -> Tensor:
- """Predicts text in image.
-
- Args:
- x (Tensor): Image(s) to extract text from.
-
- Shapes:
- - x: :math: `(B, H, W)`
- - output: :math: `(B, S)`
-
- Returns:
- Tensor: A tensor of token indices of the predictions from the model.
- """
- start_index = self.tokenizer.start_index
- end_index = self.tokenizer.end_index
- pad_index = self.tokenizer.pad_index
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- img_features = self.network.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- indecies[:, 0] = start_index
-
- for Sy in range(1, self.max_output_len):
- tokens = indecies[:, :Sy] # (B, Sy)
- logits = self.network.decode(tokens, img_features) # (B, C, Sy)
- indecies_ = torch.argmax(logits, dim=1) # (B, Sy)
- indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- (indecies[:, Sy - 1] == end_index) | (indecies[:, Sy - 1] == pad_index)
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for Sy in range(1, self.max_output_len):
- idx = (indecies[:, Sy - 1] == end_index) | (
- indecies[:, Sy - 1] == pad_index
- )
- indecies[idx, Sy] = pad_index
-
- return indecies
+ return self.decoder(x)