summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--text_recognizer/model/__init__.py1
-rw-r--r--text_recognizer/model/base.py (renamed from text_recognizer/models/base.py)5
-rw-r--r--text_recognizer/model/greedy_decoder.py58
-rw-r--r--text_recognizer/model/transformer.py (renamed from text_recognizer/models/transformer.py)24
-rw-r--r--text_recognizer/models/__init__.py2
-rw-r--r--text_recognizer/models/greedy_decoder.py51
6 files changed, 68 insertions, 73 deletions
diff --git a/text_recognizer/model/__init__.py b/text_recognizer/model/__init__.py
new file mode 100644
index 0000000..1982daf
--- /dev/null
+++ b/text_recognizer/model/__init__.py
@@ -0,0 +1 @@
+"""PyTorch Lightning models modules."""
diff --git a/text_recognizer/models/base.py b/text_recognizer/model/base.py
index 4dd5cdf..1cff796 100644
--- a/text_recognizer/models/base.py
+++ b/text_recognizer/model/base.py
@@ -5,14 +5,14 @@ import hydra
import torch
from loguru import logger as log
from omegaconf import DictConfig
-from pytorch_lightning import LightningModule
+import pytorch_lightning as L
from torch import nn, Tensor
from torchmetrics import Accuracy
from text_recognizer.data.tokenizer import Tokenizer
-class LitBase(LightningModule):
+class LitBase(L.LightningModule):
"""Abstract PyTorch Lightning class."""
def __init__(
@@ -41,7 +41,6 @@ class LitBase(LightningModule):
epoch: int,
batch_idx: int,
optimizer: Type[torch.optim.Optimizer],
- optimizer_idx: int,
) -> None:
"""Optimal way to set grads to zero."""
optimizer.zero_grad(set_to_none=True)
diff --git a/text_recognizer/model/greedy_decoder.py b/text_recognizer/model/greedy_decoder.py
new file mode 100644
index 0000000..5cbbb66
--- /dev/null
+++ b/text_recognizer/model/greedy_decoder.py
@@ -0,0 +1,58 @@
+"""Greedy decoder."""
+from typing import Type
+from text_recognizer.data.tokenizer import Tokenizer
+import torch
+from torch import nn, Tensor
+
+
+class GreedyDecoder:
+ def __init__(
+ self,
+ network: Type[nn.Module],
+ tokenizer: Tokenizer,
+ max_output_len: int = 682,
+ ) -> None:
+ self.network = network
+ self.start_index = tokenizer.start_index
+ self.end_index = tokenizer.end_index
+ self.pad_index = tokenizer.pad_index
+ self.max_output_len = max_output_len
+
+ def __call__(self, x: Tensor) -> Tensor:
+ bsz = x.shape[0]
+
+ # Encode image(s) to latent vectors.
+ img_features = self.network.encode(x)
+
+ # Create a placeholder matrix for storing outputs from the network
+ indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+ indecies[:, 0] = self.start_index
+
+ try:
+ for i in range(1, self.max_output_len):
+ tokens = indecies[:, :i] # (B, Sy)
+ logits = self.network.decode(tokens, img_features) # (B, C, Sy)
+ indecies_ = torch.argmax(logits, dim=1) # (B, Sy)
+ indecies[:, i : i + 1] = indecies_[:, -1:]
+
+ # Early stopping of prediction loop if token is end or padding token.
+ if (
+ (indecies[:, i - 1] == self.end_index)
+ | (indecies[:, i - 1] == self.pad_index)
+ ).all():
+ break
+
+ # Set all tokens after end token to pad token.
+ for i in range(1, self.max_output_len):
+ idx = (indecies[:, i - 1] == self.end_index) | (
+ indecies[:, i - 1] == self.pad_index
+ )
+ indecies[idx, i] = self.pad_index
+
+ return indecies
+ except Exception:
+ # TODO: investigate this error more
+ print(x.shape)
+ # print(indecies)
+ print(indecies.shape)
+ print(img_features.shape)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/model/transformer.py
index bbfaac0..23b2a3a 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/model/transformer.py
@@ -1,6 +1,6 @@
-"""Lightning model for base Transformers."""
+"""Lightning model for transformer networks."""
from typing import Callable, Optional, Sequence, Tuple, Type
-from text_recognizer.models.greedy_decoder import GreedyDecoder
+from text_recognizer.model.greedy_decoder import GreedyDecoder
import torch
from omegaconf import DictConfig
@@ -8,12 +8,10 @@ from torch import nn, Tensor
from torchmetrics import CharErrorRate, WordErrorRate
from text_recognizer.data.tokenizer import Tokenizer
-from text_recognizer.models.base import LitBase
+from text_recognizer.model.base import LitBase
class LitTransformer(LitBase):
- """A PyTorch Lightning model for transformer networks."""
-
def __init__(
self,
network: Type[nn.Module],
@@ -51,22 +49,18 @@ class LitTransformer(LitBase):
data, targets = batch
logits = self.teacher_forward(data, targets[:, :-1])
loss = self.loss_fn(logits, targets[:, 1:])
- self.log("train/loss", loss)
+ self.log("train/loss", loss, prog_bar=True)
return loss
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Validation step."""
data, targets = batch
-
- logits = self.teacher_forward(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:])
- preds = self.predict(data)
+ preds = self(data)
pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
self.val_acc(preds, targets)
self.val_cer(pred_text, target_text)
self.val_wer(pred_text, target_text)
- self.log("val/loss", loss, on_step=False, on_epoch=True)
self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
@@ -74,25 +68,21 @@ class LitTransformer(LitBase):
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Test step."""
data, targets = batch
-
- logits = self.teacher_forward(data, targets[:, :-1])
- loss = self.loss_fn(logits, targets[:, 1:])
preds = self(data)
pred_text, target_text = self._to_tokens(preds), self._to_tokens(targets)
self.test_acc(preds, targets)
self.test_cer(pred_text, target_text)
self.test_wer(pred_text, target_text)
- self.log("test/loss", loss, on_step=False, on_epoch=True)
self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
def _to_tokens(
self,
- indecies: Tensor,
+ indices: Tensor,
) -> Sequence[str]:
- return [self.tokenizer.decode(i) for i in indecies]
+ return [self.tokenizer.decode(i) for i in indices]
@torch.no_grad()
def predict(self, x: Tensor) -> Tensor:
diff --git a/text_recognizer/models/__init__.py b/text_recognizer/models/__init__.py
deleted file mode 100644
index cc02487..0000000
--- a/text_recognizer/models/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-"""PyTorch Lightning models modules."""
-from text_recognizer.models.transformer import LitTransformer
diff --git a/text_recognizer/models/greedy_decoder.py b/text_recognizer/models/greedy_decoder.py
deleted file mode 100644
index 9d2f192..0000000
--- a/text_recognizer/models/greedy_decoder.py
+++ /dev/null
@@ -1,51 +0,0 @@
-"""Greedy decoder."""
-from typing import Type
-from text_recognizer.data.tokenizer import Tokenizer
-import torch
-from torch import nn, Tensor
-
-
-class GreedyDecoder:
- def __init__(
- self,
- network: Type[nn.Module],
- tokenizer: Tokenizer,
- max_output_len: int = 682,
- ) -> None:
- self.network = network
- self.start_index = tokenizer.start_index
- self.end_index = tokenizer.end_index
- self.pad_index = tokenizer.pad_index
- self.max_output_len = max_output_len
-
- def __call__(self, x: Tensor) -> Tensor:
- bsz = x.shape[0]
-
- # Encode image(s) to latent vectors.
- img_features = self.network.encode(x)
-
- # Create a placeholder matrix for storing outputs from the network
- indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- indecies[:, 0] = self.start_index
-
- for Sy in range(1, self.max_output_len):
- tokens = indecies[:, :Sy] # (B, Sy)
- logits = self.network.decode(tokens, img_features) # (B, C, Sy)
- indecies_ = torch.argmax(logits, dim=1) # (B, Sy)
- indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
-
- # Early stopping of prediction loop if token is end or padding token.
- if (
- (indecies[:, Sy - 1] == self.end_index)
- | (indecies[:, Sy - 1] == self.pad_index)
- ).all():
- break
-
- # Set all tokens after end token to pad token.
- for Sy in range(1, self.max_output_len):
- idx = (indecies[:, Sy - 1] == self.end_index) | (
- indecies[:, Sy - 1] == self.pad_index
- )
- indecies[idx, Sy] = self.pad_index
-
- return indecies