diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-07-30 23:15:03 +0200 |
commit | 7268035fb9e57342612a8cc50a1fe04e8841ca2f (patch) | |
tree | 8d4cf3743975bd25f2c04d6a56ff3d4608a7e8d9 /text_recognizer/models | |
parent | 92fc1c7ed2f9f64552be8f71d9b8ab0d5a0a88d4 (diff) |
attr bug fix, properly loading network
Diffstat (limited to 'text_recognizer/models')
-rw-r--r-- | text_recognizer/models/base.py | 4 | ||||
-rw-r--r-- | text_recognizer/models/metrics.py | 8 | ||||
-rw-r--r-- | text_recognizer/models/transformer.py | 80 | ||||
-rw-r--r-- | text_recognizer/models/vqvae.py | 5 |
4 files changed, 76 insertions, 21 deletions
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py index 3e02261..dfb4ca4 100644 --- a/text_recognizer/models/base.py +++ b/text_recognizer/models/base.py @@ -11,8 +11,6 @@ from torch import nn from torch import Tensor import torchmetrics -from text_recognizer.networks.base import BaseNetwork - @attr.s class BaseLitModel(LightningModule): @@ -21,7 +19,7 @@ class BaseLitModel(LightningModule): def __attrs_pre_init__(self) -> None: super().__init__() - network: Type[BaseNetwork] = attr.ib() + network: Type[nn.Module] = attr.ib() criterion_config: DictConfig = attr.ib(converter=DictConfig) optimizer_config: DictConfig = attr.ib(converter=DictConfig) lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig) diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py index 4117ae2..9793157 100644 --- a/text_recognizer/models/metrics.py +++ b/text_recognizer/models/metrics.py @@ -1,5 +1,5 @@ """Character Error Rate (CER).""" -from typing import Set, Sequence +from typing import Set import attr import editdistance @@ -12,7 +12,7 @@ from torchmetrics import Metric class CharacterErrorRate(Metric): """Character error rate metric, computed using Levenshtein distance.""" - ignore_tokens: Set = attr.ib(converter=set) + ignore_indices: Set = attr.ib(converter=set) error: Tensor = attr.ib(init=False) total: Tensor = attr.ib(init=False) @@ -25,8 +25,8 @@ class CharacterErrorRate(Metric): """Update CER.""" bsz = preds.shape[0] for index in range(bsz): - pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens] - target = [t for t in targets[index].tolist() if t not in self.ignore_tokens] + pred = [p for p in preds[index].tolist() if p not in self.ignore_indices] + target = [t for t in targets[index].tolist() if t not in self.ignore_indices] distance = editdistance.distance(pred, target) error = distance / max(len(pred), len(target)) self.error += error diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index f5cb491..7a9d566 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -1,11 +1,11 @@ """PyTorch Lightning model for base Transformers.""" -from typing import Dict, List, Optional, Sequence, Union, Tuple, Type +from typing import Sequence, Tuple, Type import attr -import hydra -from omegaconf import DictConfig -from torch import nn, Tensor +import torch +from torch import Tensor +from text_recognizer.data.mappings import AbstractMapping from text_recognizer.models.metrics import CharacterErrorRate from text_recognizer.models.base import BaseLitModel @@ -13,18 +13,31 @@ from text_recognizer.models.base import BaseLitModel @attr.s(auto_attribs=True) class TransformerLitModel(BaseLitModel): """A PyTorch Lightning model for transformer networks.""" + mapping: Type[AbstractMapping] = attr.ib() + start_token: str = attr.ib() + end_token: str = attr.ib() + pad_token: str = attr.ib() - ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",)) + start_index: Tensor = attr.ib(init=False) + end_index: Tensor = attr.ib(init=False) + pad_index: Tensor = attr.ib(init=False) + + ignore_indices: Sequence[str] = attr.ib(init=False) val_cer: CharacterErrorRate = attr.ib(init=False) test_cer: CharacterErrorRate = attr.ib(init=False) def __attrs_post_init__(self) -> None: - self.val_cer = CharacterErrorRate(self.ignore_tokens) - self.test_cer = CharacterErrorRate(self.ignore_tokens) + """Post init configuration.""" + self.start_index = self.mapping.get_index(self.start_token) + self.end_index = self.mapping.get_index(self.end_token) + self.pad_index = self.mapping.get_index(self.pad_token) + self.ignore_indices = [self.start_index, self.end_index, self.pad_index] + self.val_cer = CharacterErrorRate(self.ignore_indices) + self.test_cer = CharacterErrorRate(self.ignore_indices) def forward(self, data: Tensor) -> Tensor: """Forward pass with the transformer network.""" - return self.network.predict(data) + return self.predict(data) def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" @@ -38,17 +51,64 @@ class TransformerLitModel(BaseLitModel): """Validation step.""" data, targets = batch + # Compute the loss. logits = self.network(data, targets[:-1]) loss = self.loss_fn(logits, targets[1:]) self.log("val/loss", loss, prog_bar=True) - pred = self.network.predict(data) + # Get the token prediction. + pred = self(data) self.val_cer(pred, targets) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" data, targets = batch - pred = self.network.predict(data) + + # Compute the text prediction. + pred = self(data) self.test_cer(pred, targets) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + + def predict(self, x: Tensor) -> Tensor: + """Predicts text in image. + + Args: + x (Tensor): Image(s) to extract text from. + + Shapes: + - x: :math: `(B, H, W)` + - output: :math: `(B, S)` + + Returns: + Tensor: A tensor of token indices of the predictions from the model. + """ + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + z = self.network.encode(x) + + # Create a placeholder matrix for storing outputs from the network + output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + output[:, 0] = self.start_index + + for i in range(1, self.max_output_len): + context = output[:, :i] # (bsz, i) + logits = self.network.decode(z, context) # (i, bsz, c) + tokens = torch.argmax(logits, dim=-1) # (i, bsz) + output[:, i : i + 1] = tokens[-1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + output[:, i - 1] == self.end_index | output[: i - 1] == self.pad_index + ).all(): + break + + # Set all tokens after end token to pad token. + for i in range(1, self.max_output_len): + idx = ( + output[:, i - 1] == self.end_index | output[:, i - 1] == self.pad_index + ) + output[idx, i] = self.pad_index + + return output diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py index 0172163..e215e14 100644 --- a/text_recognizer/models/vqvae.py +++ b/text_recognizer/models/vqvae.py @@ -34,8 +34,6 @@ class VQVAELitModel(BaseLitModel): loss = self.loss_fn(reconstructions, data) loss += vq_loss self.log("val/loss", loss, prog_bar=True) - title = "val_pred_examples" - self._log_prediction(data, reconstructions, title) def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Test step.""" @@ -43,5 +41,4 @@ class VQVAELitModel(BaseLitModel): reconstructions, vq_loss = self.network(data) loss = self.loss_fn(reconstructions, data) loss += vq_loss - title = "test_pred_examples" - self._log_prediction(data, reconstructions, title) + self.log("test/loss", loss) |