author     Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-10-04 22:08:38 +0200
committer  Gustaf Rydholm <gustaf.rydholm@gmail.com>   2022-10-04 22:08:38 +0200
commit     8c4a0c2603975cfc63f4e4019386e001387c42c9 (patch)
tree       213dac666a8ac71e7a48608ee492b80572a23584
parent     3ec10be9141bf71fb10d699b31a66b4e5046973c (diff)
Add greedy decoder
-rw-r--r--  README.md                                   2
-rw-r--r--  text_recognizer/models/greedy_decoder.py   51
-rw-r--r--  text_recognizer/models/transformer.py      65
-rw-r--r--  training/conf/config.yaml                    1
-rw-r--r--  training/run.py                             10
5 files changed, 74 insertions, 55 deletions
diff --git a/README.md b/README.md
index f549cda..6f95001 100644
--- a/README.md
+++ b/README.md
@@ -73,7 +73,7 @@ Ideas of mine that did not work unfortunately:
 - [ ] Evaluation
 - [ ] Wandb artifact fetcher
 - [ ] fix linting
-- [ ] Modularize the decoder
+- [x] Modularize the decoder
 - [ ] Add kv cache
 - [x] Fix stems
 - [x] residual attn
diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py
new file mode 100644
index 0000000..9d2f192
--- /dev/null
+++ b/text_recognizer/models/greedy_decoder.py
@@ -0,0 +1,51 @@
+"""Greedy decoder."""
+from typing import Type
+from text_recognizer.data.tokenizer import Tokenizer
+import torch
+from torch import nn, Tensor
+
+
+class GreedyDecoder:
+    def __init__(
+        self,
+        network: Type[nn.Module],
+        tokenizer: Tokenizer,
+        max_output_len: int = 682,
+    ) -> None:
+        self.network = network
+        self.start_index = tokenizer.start_index
+        self.end_index = tokenizer.end_index
+        self.pad_index = tokenizer.pad_index
+        self.max_output_len = max_output_len
+
+    def __call__(self, x: Tensor) -> Tensor:
+        bsz = x.shape[0]
+
+        # Encode image(s) to latent vectors.
+        img_features = self.network.encode(x)
+
+        # Create a placeholder matrix for storing outputs from the network.
+        indices = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+        indices[:, 0] = self.start_index
+
+        for Sy in range(1, self.max_output_len):
+            tokens = indices[:, :Sy]  # (B, Sy)
+            logits = self.network.decode(tokens, img_features)  # (B, C, Sy)
+            indices_ = torch.argmax(logits, dim=1)  # (B, Sy)
+            indices[:, Sy : Sy + 1] = indices_[:, -1:]
+
+            # Stop the prediction loop early once every sequence has emitted
+            # an end or padding token.
+            if (
+                (indices[:, Sy - 1] == self.end_index)
+                | (indices[:, Sy - 1] == self.pad_index)
+            ).all():
+                break
+
+        # Set all tokens after the end token to the pad token.
+        for Sy in range(1, self.max_output_len):
+            idx = (indices[:, Sy - 1] == self.end_index) | (
+                indices[:, Sy - 1] == self.pad_index
+            )
+            indices[idx, Sy] = self.pad_index
+
+        return indices
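The decoder is deliberately decoupled from the Lightning module: it only assumes a network exposing encode(x) -> features and decode(tokens, features) -> logits of shape (B, C, Sy), plus an object carrying start/end/pad indices. A minimal standalone sketch of that contract, with a stub network and made-up token indices (illustrative only, not part of this commit):

"""Sketch: driving GreedyDecoder with a stub network (illustrative only)."""
from types import SimpleNamespace

import torch
from torch import nn, Tensor

from text_recognizer.models.greedy_decoder import GreedyDecoder


class StubNetwork(nn.Module):
    """Toy stand-in that satisfies the encode/decode contract with C=10 classes."""

    def encode(self, x: Tensor) -> Tensor:
        return torch.zeros(x.shape[0], 1)  # placeholder latent features

    def decode(self, tokens: Tensor, features: Tensor) -> Tensor:
        bsz, sy = tokens.shape
        return torch.randn(bsz, 10, sy)  # random (B, C, Sy) logits


# Any object with these three attributes will do; the index values are made up.
tokenizer = SimpleNamespace(start_index=1, end_index=2, pad_index=3)

decoder = GreedyDecoder(network=StubNetwork(), tokenizer=tokenizer, max_output_len=8)
out = decoder(torch.randn(2, 28, 28))  # two fake images, (B, H, W)
print(out.shape)  # torch.Size([2, 8])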
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index 6048901..bbfaac0 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,6 +1,6 @@
 """Lightning model for base Transformers."""
-from collections.abc import Sequence
-from typing import Optional, Tuple, Type
+from typing import Callable, Optional, Sequence, Tuple, Type
+from text_recognizer.models.greedy_decoder import GreedyDecoder
 
 import torch
 from omegaconf import DictConfig
@@ -20,6 +20,7 @@ class LitTransformer(LitBase):
         loss_fn: Type[nn.Module],
         optimizer_config: DictConfig,
         tokenizer: Tokenizer,
+        decoder: Callable = GreedyDecoder,
         lr_scheduler_config: Optional[DictConfig] = None,
         max_output_len: int = 682,
     ) -> None:
@@ -35,9 +36,10 @@
         self.test_cer = CharErrorRate()
         self.val_wer = WordErrorRate()
         self.test_wer = WordErrorRate()
+        self.decoder = decoder
 
     def forward(self, data: Tensor) -> Tensor:
-        """Forward pass with the transformer network."""
+        """Autoregressive forward pass."""
         return self.predict(data)
 
     def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
@@ -59,7 +61,7 @@
         logits = self.teacher_forward(data, targets[:, :-1])
         loss = self.loss_fn(logits, targets[:, 1:])
         preds = self.predict(data)
-        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
 
         self.val_acc(preds, targets)
         self.val_cer(pred_text, target_text)
@@ -76,7 +78,7 @@
         logits = self.teacher_forward(data, targets[:, :-1])
         loss = self.loss_fn(logits, targets[:, 1:])
         preds = self(data)
-        pred_text, target_text = self._get_text(preds), self._get_text(targets)
+        pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
 
         self.test_acc(preds, targets)
         self.test_cer(pred_text, target_text)
@@ -86,55 +88,12 @@
         self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
-    def _get_text(
+    def _to_tokens(
         self,
-        xs: Tensor,
-    ) -> Tuple[Sequence[str], Sequence[str]]:
-        return [self.tokenizer.decode(x) for x in xs]
+        indices: Tensor,
+    ) -> Sequence[str]:
+        return [self.tokenizer.decode(i) for i in indices]
 
     @torch.no_grad()
     def predict(self, x: Tensor) -> Tensor:
-        """Predicts text in image.
-
-        Args:
-            x (Tensor): Image(s) to extract text from.
-
-        Shapes:
-            - x: :math: `(B, H, W)`
-            - output: :math: `(B, S)`
-
-        Returns:
-            Tensor: A tensor of token indices of the predictions from the model.
-        """
-        start_index = self.tokenizer.start_index
-        end_index = self.tokenizer.end_index
-        pad_index = self.tokenizer.pad_index
-        bsz = x.shape[0]
-
-        # Encode image(s) to latent vectors.
-        img_features = self.network.encode(x)
-
-        # Create a placeholder matrix for storing outputs from the network
-        indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
-        indecies[:, 0] = start_index
-
-        for Sy in range(1, self.max_output_len):
-            tokens = indecies[:, :Sy]  # (B, Sy)
-            logits = self.network.decode(tokens, img_features)  # (B, C, Sy)
-            indecies_ = torch.argmax(logits, dim=1)  # (B, Sy)
-            indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
-
-            # Early stopping of prediction loop if token is end or padding token.
-            if (
-                (indecies[:, Sy - 1] == end_index) | (indecies[:, Sy - 1] == pad_index)
-            ).all():
-                break
-
-        # Set all tokens after end token to pad token.
-        for Sy in range(1, self.max_output_len):
-            idx = (indecies[:, Sy - 1] == end_index) | (
-                indecies[:, Sy - 1] == pad_index
-            )
-            indecies[idx, Sy] = pad_index
-
-        return indecies
+        return self.decoder(x)
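With predict reduced to a single delegation, the decoding strategy is swappable without touching LitTransformer: any callable mapping images to token indices fits the decoder argument. As an illustration of what the modularization buys, a hypothetical temperature-sampling variant (not in this commit) could reuse GreedyDecoder's constructor and attributes:

"""Sketch of a hypothetical alternative decoder (not in this commit)."""
import torch
from torch import Tensor

from text_recognizer.models.greedy_decoder import GreedyDecoder


class SamplingDecoder(GreedyDecoder):
    """Drop-in variant: samples the next token from softmax(logits / T)
    instead of taking the argmax at every step."""

    def __init__(self, network, tokenizer, max_output_len: int = 682, temperature: float = 1.0) -> None:
        super().__init__(network, tokenizer, max_output_len)
        self.temperature = temperature

    def __call__(self, x: Tensor) -> Tensor:
        bsz = x.shape[0]
        img_features = self.network.encode(x)
        indices = torch.full(
            (bsz, self.max_output_len), self.pad_index, dtype=torch.long, device=x.device
        )
        indices[:, 0] = self.start_index
        for Sy in range(1, self.max_output_len):
            logits = self.network.decode(indices[:, :Sy], img_features)  # (B, C, Sy)
            probs = (logits[:, :, -1] / self.temperature).softmax(dim=-1)  # (B, C)
            indices[:, Sy] = torch.multinomial(probs, num_samples=1).squeeze(-1)
            if ((indices[:, Sy] == self.end_index) | (indices[:, Sy] == self.pad_index)).all():
                break
        # NOTE: a production version would also mask everything after each end
        # token to pad, as GreedyDecoder does in its second loop.
        return indices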
diff --git a/training/conf/config.yaml b/training/conf/config.yaml
index a95307f..e57a8a8 100644
--- a/training/conf/config.yaml
+++ b/training/conf/config.yaml
@@ -4,6 +4,7 @@ defaults:
   - _self_
   - callbacks: default
   - criterion: cross_entropy
+  - decoder: greedy
   - datamodule: iam_extended_paragraphs
   - hydra: default
   - logger: wandb
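The defaults list now pulls in a decoder group, but the group file itself is not part of this diff. Since run.py passes network and tokenizer explicitly (see below), training/conf/decoder/greedy.yaml presumably needs only the target and the remaining constructor argument; a sketch of its assumed contents:

# training/conf/decoder/greedy.yaml (assumed contents; file not shown in this diff)
_target_: text_recognizer.models.greedy_decoder.GreedyDecoder
max_output_len: 682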
diff --git a/training/run.py b/training/run.py
index 429b1a2..288a1ef 100644
--- a/training/run.py
+++ b/training/run.py
@@ -1,5 +1,5 @@
 """Script to run experiments."""
-from typing import List, Optional, Type
+from typing import Callable, List, Optional, Type
 
 import hydra
 from loguru import logger as log
@@ -34,11 +34,19 @@ def run(config: DictConfig) -> Optional[float]:
     log.info(f"Instantiating criterion <{config.criterion._target_}>")
     loss_fn: Type[nn.Module] = hydra.utils.instantiate(config.criterion)
 
+    log.info(f"Instantiating decoder <{config.decoder._target_}>")
+    decoder: Callable = hydra.utils.instantiate(
+        config.decoder,
+        network=network,
+        tokenizer=datamodule.tokenizer,
+    )
+
     log.info(f"Instantiating model <{config.model._target_}>")
     model: LightningModule = hydra.utils.instantiate(
         config.model,
         network=network,
         tokenizer=datamodule.tokenizer,
+        decoder=decoder,
        loss_fn=loss_fn,
         optimizer_config=config.optimizer,
         lr_scheduler_config=config.lr_scheduler,
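Because the decoder is now an ordinary Hydra config group, the strategy can be tuned or swapped from the command line without code changes. Assuming the group file sketched above, an override would look like:

python training/run.py decoder.max_output_len=512

and a future alternative (e.g. decoder=beam) would only require adding another yaml file to the group.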