diff options
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/model/__init__.py | 1 | ||||
-rw-r--r-- | text_recognizer/model/base.py (renamed from text_recognizer/models/base.py) | 5 | ||||
-rw-r--r-- | text_recognizer/model/greedy_decoder.py | 58 | ||||
-rw-r--r-- | text_recognizer/model/transformer.py (renamed from text_recognizer/models/transformer.py) | 24 | ||||
-rw-r--r-- | text_recognizer/models/__init__.py | 2 | ||||
-rw-r--r-- | text_recognizer/models/greedy_decoder.py | 51 |
6 files changed, 68 insertions, 73 deletions
diff --git a/text_recognizer/model/__init__.py b/text_recognizer/model/__init__.py new file mode 100644 index 0000000..1982daf --- /dev/null +++ b/text_recognizer/model/__init__.py @@ -0,0 +1 @@ +"""PyTorch Lightning models modules.""" diff --git a/text_recognizer/models/base.py b/text_recognizer/model/base.py index 4dd5cdf..1cff796 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/model/base.py @@ -5,14 +5,14 @@ import hydra import torch from loguru import logger as log from omegaconf import DictConfig -from pytorch_lightning import LightningModule +import pytorch_lightning as L from torch import nn, Tensor from torchmetrics import Accuracy from text_recognizer.data.tokenizer import Tokenizer -class LitBase(LightningModule): +class LitBase(L.LightningModule): """Abstract PyTorch Lightning class.""" def __init__( @@ -41,7 +41,6 @@ class LitBase(LightningModule): epoch: int, batch_idx: int, optimizer: Type[torch.optim.Optimizer], - optimizer_idx: int, ) -> None: """Optimal way to set grads to zero.""" optimizer.zero_grad(set_to_none=True) diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py new file mode 100644 index 0000000..5cbbb66 --- /dev/null +++ b/text_recognizer/model/greedy_decoder.py @@ -0,0 +1,58 @@ +"""Greedy decoder.""" +from typing import Type +from text_recognizer.data.tokenizer import Tokenizer +import torch +from torch import nn, Tensor + + +class GreedyDecoder: + def __init__( + self, + network: Type[nn.Module], + tokenizer: Tokenizer, + max_output_len: int = 682, + ) -> None: + self.network = network + self.start_index = tokenizer.start_index + self.end_index = tokenizer.end_index + self.pad_index = tokenizer.pad_index + self.max_output_len = max_output_len + + def __call__(self, x: Tensor) -> Tensor: + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + img_features = self.network.encode(x) + + # Create a placeholder matrix for storing outputs from the network + indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + indecies[:, 0] = self.start_index + + try: + for i in range(1, self.max_output_len): + tokens = indecies[:, :i] # (B, Sy) + logits = self.network.decode(tokens, img_features) # (B, C, Sy) + indecies_ = torch.argmax(logits, dim=1) # (B, Sy) + indecies[:, i : i + 1] = indecies_[:, -1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + (indecies[:, i - 1] == self.end_index) + | (indecies[:, i - 1] == self.pad_index) + ).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = (indecies[:, i - 1] == self.end_index) | ( + indecies[:, i - 1] == self.pad_index + ) + indecies[idx, i] = self.pad_index + + return indecies + except Exception: + # TODO: investigate this error more + print(x.shape) + # print(indecies) + print(indecies.shape) + print(img_features.shape) diff --git a/text_recognizer/models/transformer.py b/text_recognizer/model/transformer.py index bbfaac0..23b2a3a 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/model/transformer.py @@ -1,6 +1,6 @@ -"""Lightning model for base Transformers.""" +"""Lightning model for transformer networks.""" from typing import Callable, Optional, Sequence, Tuple, Type -from text_recognizer.models.greedy_decoder import GreedyDecoder +from text_recognizer.model.greedy_decoder import GreedyDecoder import torch from omegaconf import DictConfig @@ -8,12 +8,10 @@ from torch import nn, Tensor from torchmetrics import CharErrorRate, WordErrorRate from text_recognizer.data.tokenizer import Tokenizer -from text_recognizer.models.base import LitBase +from text_recognizer.model.base import LitBase class LitTransformer(LitBase): - """A PyTorch Lightning model for transformer networks.""" - def __init__( self, network: Type[nn.Module], @@ -51,22 +49,18 @@ class LitTransformer(LitBase): data, targets = batch logits = self.teacher_forward(data, targets[:, :-1]) loss = self.loss_fn(logits, targets[:, 1:]) - self.log("train/loss", loss) + self.log("train/loss", loss, prog_bar=True) return loss def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, targets = batch - - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) - preds = self.predict(data) + preds = self(data) pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) self.val_acc(preds, targets) self.val_cer(pred_text, target_text) self.val_wer(pred_text, target_text) - self.log("val/loss", loss, on_step=False, on_epoch=True) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) @@ -74,25 +68,21 @@ class LitTransformer(LitBase): def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, targets = batch - - logits = self.teacher_forward(data, targets[:, :-1]) - loss = self.loss_fn(logits, targets[:, 1:]) preds = self(data) pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets) self.test_acc(preds, targets) self.test_cer(pred_text, target_text) self.test_wer(pred_text, target_text) - self.log("test/loss", loss, on_step=False, on_epoch=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) def _to_tokens( self, - indecies: Tensor, + indices: Tensor, ) -> Sequence[str]: - return [self.tokenizer.decode(i) for i in indecies] + return [self.tokenizer.decode(i) for i in indices] @torch.no_grad() def predict(self, x: Tensor) -> Tensor: diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py deleted file mode 100644 index cc02487..0000000 --- a/text_recognizer/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -"""PyTorch Lightning models modules.""" -from text_recognizer.models.transformer import LitTransformer diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py deleted file mode 100644 index 9d2f192..0000000 --- a/text_recognizer/models/greedy_decoder.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Greedy decoder.""" -from typing import Type -from text_recognizer.data.tokenizer import Tokenizer -import torch -from torch import nn, Tensor - - -class GreedyDecoder: - def __init__( - self, - network: Type[nn.Module], - tokenizer: Tokenizer, - max_output_len: int = 682, - ) -> None: - self.network = network - self.start_index = tokenizer.start_index - self.end_index = tokenizer.end_index - self.pad_index = tokenizer.pad_index - self.max_output_len = max_output_len - - def __call__(self, x: Tensor) -> Tensor: - bsz = x.shape[0] - - # Encode image(s) to latent vectors. - img_features = self.network.encode(x) - - # Create a placeholder matrix for storing outputs from the network - indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - indecies[:, 0] = self.start_index - - for Sy in range(1, self.max_output_len): - tokens = indecies[:, :Sy] # (B, Sy) - logits = self.network.decode(tokens, img_features) # (B, C, Sy) - indecies_ = torch.argmax(logits, dim=1) # (B, Sy) - indecies[:, Sy : Sy + 1] = indecies_[:, -1:] - - # Early stopping of prediction loop if token is end or padding token. - if ( - (indecies[:, Sy - 1] == self.end_index) - | (indecies[:, Sy - 1] == self.pad_index) - ).all(): - break - - # Set all tokens after end token to pad token. - for Sy in range(1, self.max_output_len): - idx = (indecies[:, Sy - 1] == self.end_index) | ( - indecies[:, Sy - 1] == self.pad_index - ) - indecies[idx, Sy] = self.pad_index - - return indecies |