path: root/text_recognizer/models
Diffstat (limited to 'text_recognizer/models')
-rw-r--r--  text_recognizer/models/base.py         |  4
-rw-r--r--  text_recognizer/models/metrics.py      |  8
-rw-r--r--  text_recognizer/models/transformer.py  | 80
-rw-r--r--  text_recognizer/models/vqvae.py        |  5
4 files changed, 76 insertions(+), 21 deletions(-)
diff --git a/text_recognizer/models/base.py b/text_recognizer/models/base.py
index 3e02261..dfb4ca4 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/models/base.py
@@ -11,8 +11,6 @@ from torch import nn
from torch import Tensor
import torchmetrics
-from text_recognizer.networks.base import BaseNetwork
-
@attr.s
class BaseLitModel(LightningModule):
@@ -21,7 +19,7 @@ class BaseLitModel(LightningModule):
def __attrs_pre_init__(self) -> None:
super().__init__()
- network: Type[BaseNetwork] = attr.ib()
+ network: Type[nn.Module] = attr.ib()
criterion_config: DictConfig = attr.ib(converter=DictConfig)
optimizer_config: DictConfig = attr.ib(converter=DictConfig)
lr_scheduler_config: DictConfig = attr.ib(converter=DictConfig)
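Usage note: after this change any plain nn.Module instance can back the Lightning model. A minimal construction sketch, assuming the four attributes shown in this hunk are the only required ones and that the configs follow hydra's _target_ convention; the config contents are invented for illustration:

    from torch import nn

    # Any nn.Module now qualifies as the network; the plain dicts pass
    # through the converter=DictConfig attributes declared above.
    net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    model = BaseLitModel(
        network=net,
        criterion_config={"_target_": "torch.nn.CrossEntropyLoss"},
        optimizer_config={"_target_": "torch.optim.Adam", "lr": 3e-4},
        lr_scheduler_config={"_target_": "torch.optim.lr_scheduler.CosineAnnealingLR", "T_max": 10},
    )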
diff --git a/text_recognizer/models/metrics.py b/text_recognizer/models/metrics.py
index 4117ae2..9793157 100644
--- a/text_recognizer/models/metrics.py
+++ b/text_recognizer/models/metrics.py
@@ -1,5 +1,5 @@
"""Character Error Rate (CER)."""
-from typing import Set, Sequence
+from typing import Set
import attr
import editdistance
@@ -12,7 +12,7 @@ from torchmetrics import Metric
class CharacterErrorRate(Metric):
"""Character error rate metric, computed using Levenshtein distance."""
- ignore_tokens: Set = attr.ib(converter=set)
+ ignore_indices: Set = attr.ib(converter=set)
error: Tensor = attr.ib(init=False)
total: Tensor = attr.ib(init=False)
@@ -25,8 +25,8 @@ class CharacterErrorRate(Metric):
"""Update CER."""
bsz = preds.shape[0]
for index in range(bsz):
- pred = [p for p in preds[index].tolist() if p not in self.ignore_tokens]
- target = [t for t in targets[index].tolist() if t not in self.ignore_tokens]
+ pred = [p for p in preds[index].tolist() if p not in self.ignore_indices]
+ target = [t for t in targets[index].tolist() if t not in self.ignore_indices]
distance = editdistance.distance(pred, target)
error = distance / max(len(pred), len(target))
self.error += error
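Usage note: a minimal sketch of the renamed metric, assuming the standard torchmetrics update/compute interface; the index values 1, 2, 3 are invented stand-ins for the start/end/pad indices:

    import torch

    cer = CharacterErrorRate(ignore_indices={1, 2, 3})
    preds = torch.tensor([[1, 10, 11, 12, 2, 3]])    # predicted token indices
    targets = torch.tensor([[1, 10, 11, 13, 2, 3]])  # ground-truth token indices
    cer.update(preds, targets)
    cer.compute()  # 1 edit over 3 kept tokens -> tensor(0.3333)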
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index f5cb491..7a9d566 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -1,11 +1,11 @@
"""PyTorch Lightning model for base Transformers."""
-from typing import Dict, List, Optional, Sequence, Union, Tuple, Type
+from typing import Sequence, Tuple, Type
import attr
-import hydra
-from omegaconf import DictConfig
-from torch import nn, Tensor
+import torch
+from torch import Tensor
+from text_recognizer.data.mappings import AbstractMapping
from text_recognizer.models.metrics import CharacterErrorRate
from text_recognizer.models.base import BaseLitModel
@@ -13,18 +13,31 @@ from text_recognizer.models.base import BaseLitModel
@attr.s(auto_attribs=True)
class TransformerLitModel(BaseLitModel):
"""A PyTorch Lightning model for transformer networks."""
+ mapping: Type[AbstractMapping] = attr.ib()
+ start_token: str = attr.ib()
+ end_token: str = attr.ib()
+ pad_token: str = attr.ib()
- ignore_tokens: Sequence[str] = attr.ib(default=("<s>", "<e>", "<p>",))
+ start_index: Tensor = attr.ib(init=False)
+ end_index: Tensor = attr.ib(init=False)
+ pad_index: Tensor = attr.ib(init=False)
+
+ ignore_indices: Sequence[Tensor] = attr.ib(init=False)
val_cer: CharacterErrorRate = attr.ib(init=False)
test_cer: CharacterErrorRate = attr.ib(init=False)
def __attrs_post_init__(self) -> None:
- self.val_cer = CharacterErrorRate(self.ignore_tokens)
- self.test_cer = CharacterErrorRate(self.ignore_tokens)
+ """Post init configuration."""
+ self.start_index = self.mapping.get_index(self.start_token)
+ self.end_index = self.mapping.get_index(self.end_token)
+ self.pad_index = self.mapping.get_index(self.pad_token)
+ self.ignore_indices = [self.start_index, self.end_index, self.pad_index]
+ self.val_cer = CharacterErrorRate(self.ignore_indices)
+ self.test_cer = CharacterErrorRate(self.ignore_indices)
def forward(self, data: Tensor) -> Tensor:
"""Forward pass with the transformer network."""
- return self.network.predict(data)
+ return self.predict(data)
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
@@ -38,17 +51,64 @@ class TransformerLitModel(BaseLitModel):
"""Validation step."""
data, targets = batch
+ # Compute the loss.
logits = self.network(data, targets[:-1])
loss = self.loss_fn(logits, targets[1:])
self.log("val/loss", loss, prog_bar=True)
- pred = self.network.predict(data)
+ # Get the token prediction.
+ pred = self(data)
self.val_cer(pred, targets)
self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
data, targets = batch
- pred = self.network.predict(data)
+
+ # Compute the text prediction.
+ pred = self(data)
self.test_cer(pred, targets)
self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
+
+ def predict(self, x: Tensor) -> Tensor:
+ """Predicts text in image.
+
+ Args:
+ x (Tensor): Image(s) to extract text from.
+
+ Shapes:
+ - x: :math:`(B, H, W)`
+ - output: :math:`(B, S)`
+
+ Returns:
+ Tensor: A tensor of token indices of the predictions from the model.
+ """
+ bsz = x.shape[0]
+
+ # Encode image(s) to latent vectors.
+ z = self.network.encode(x)
+
+ # Create a placeholder matrix for storing outputs from the network.
+ output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+ output[:, 0] = self.start_index
+
+ for i in range(1, self.max_output_len):
+ context = output[:, :i] # (bsz, i)
+ logits = self.network.decode(z, context) # (i, bsz, c)
+ tokens = torch.argmax(logits, dim=-1) # (i, bsz)
+ output[:, i] = tokens[-1] # (bsz,)
+
+ # Stop the prediction loop early if every sequence has emitted an end or pad token.
+ if (
+ (output[:, i - 1] == self.end_index) | (output[:, i - 1] == self.pad_index)
+ ).all():
+ break
+
+ # Set all tokens after end token to pad token.
+ for i in range(1, self.max_output_len):
+ idx = (
+ (output[:, i - 1] == self.end_index) | (output[:, i - 1] == self.pad_index)
+ )
+ output[idx, i] = self.pad_index
+
+ return output
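Note: a toy run of the trailing-pad pass above (with the comparisons parenthesized), showing that everything after the first end token collapses to padding; the index values are invented:

    import torch

    end_index, pad_index = 2, 3
    output = torch.tensor([[1, 10, 2, 7, 9],
                           [1, 11, 12, 2, 8]])
    for i in range(1, output.shape[1]):
        # Positions whose previous token is end or pad are forced to pad.
        idx = (output[:, i - 1] == end_index) | (output[:, i - 1] == pad_index)
        output[idx, i] = pad_index
    # output -> tensor([[ 1, 10,  2,  3,  3],
    #                   [ 1, 11, 12,  2,  3]])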
diff --git a/text_recognizer/models/vqvae.py b/text_recognizer/models/vqvae.py
index 0172163..e215e14 100644
--- a/text_recognizer/models/vqvae.py
+++ b/text_recognizer/models/vqvae.py
@@ -34,8 +34,6 @@ class VQVAELitModel(BaseLitModel):
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
self.log("val/loss", loss, prog_bar=True)
- title = "val_pred_examples"
- self._log_prediction(data, reconstructions, title)
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
@@ -43,5 +41,4 @@ class VQVAELitModel(BaseLitModel):
reconstructions, vq_loss = self.network(data)
loss = self.loss_fn(reconstructions, data)
loss += vq_loss
- title = "test_pred_examples"
- self._log_prediction(data, reconstructions, title)
+ self.log("test/loss", loss)
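Usage note: the test step now mirrors the validation loss composition: reconstruction loss plus the vector-quantization loss returned by the network. A standalone sketch, assuming an MSE criterion (the criterion and shapes are assumptions, not fixed by this diff):

    import torch
    from torch import nn

    loss_fn = nn.MSELoss()
    data = torch.randn(4, 1, 28, 28)
    reconstructions = torch.randn(4, 1, 28, 28)  # stand-in for self.network(data)[0]
    vq_loss = torch.tensor(0.25)                 # stand-in for self.network(data)[1]
    loss = loss_fn(reconstructions, data) + vq_loss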