1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
|
"""Lightning model for base Transformers."""
from collections.abc import Sequence
from typing import Optional, Tuple, Type
import torch
from omegaconf import DictConfig
from torch import nn, Tensor
from torchmetrics import CharErrorRate, WordErrorRate
from text_recognizer.data.tokenizer import Tokenizer
from text_recognizer.models.base import LitBase
class LitTransformer(LitBase):
    """A PyTorch Lightning model for transformer networks.

    Training, validation and test losses are computed with teacher forcing
    (the network receives the ground-truth target tokens), while the text
    metrics (CER/WER) are computed on greedy autoregressive predictions
    produced by :meth:`predict`.
    """

    def __init__(
        self,
        network: Type[nn.Module],
        loss_fn: Type[nn.Module],
        optimizer_config: DictConfig,
        tokenizer: Tokenizer,
        lr_scheduler_config: Optional[DictConfig] = None,
        max_output_len: int = 682,
    ) -> None:
        """Initializes the model and its text metrics.

        Args:
            network: The transformer network; it is called directly for
                teacher forcing and via ``encode``/``decode`` in ``predict``.
            loss_fn: Loss applied to the logits and the shifted targets.
            optimizer_config: Optimizer configuration forwarded to ``LitBase``.
            tokenizer: Tokenizer providing ``start_index``, ``end_index``,
                ``pad_index`` and ``decode``.
            lr_scheduler_config: Optional learning rate scheduler
                configuration forwarded to ``LitBase``.
            max_output_len: Width of the token buffer filled by ``predict``
                (the start token occupies the first position).
        """
        # NOTE: the positional order below follows the LitBase constructor
        # (scheduler config before tokenizer), not this signature.
        super().__init__(
            network,
            loss_fn,
            optimizer_config,
            lr_scheduler_config,
            tokenizer,
        )
        self.max_output_len = max_output_len
        self.val_cer = CharErrorRate()
        self.test_cer = CharErrorRate()
        self.val_wer = WordErrorRate()
        self.test_wer = WordErrorRate()

    def forward(self, data: Tensor) -> Tensor:
        """Forward pass with the transformer network (greedy prediction)."""
        return self.predict(data)

    def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
        """Teacher-forced forward pass.

        The network is given the input together with the target tokens and
        is called once, instead of generating the output token by token.
        """
        return self.network(data, targets)

    def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
        """Training step.

        Feeds ``targets[:, :-1]`` to the network and computes the loss against
        ``targets[:, 1:]``, i.e. each position is trained to predict the next
        token.
        """
        data, targets = batch
        logits = self.teacher_forward(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Validation step.

        Logs the teacher-forced loss as well as accuracy, CER and WER of the
        greedy predictions, all aggregated over the epoch.
        """
        data, targets = batch

        # Teacher-forced loss, computed the same way as in training.
        logits = self.teacher_forward(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])

        # Metrics are evaluated on autoregressive predictions.
        preds = self.predict(data)
        pred_text, target_text = self._get_text(preds), self._get_text(targets)

        # `val_acc` is not defined in this class; presumably provided by LitBase.
        self.val_acc(preds, targets)
        self.val_cer(pred_text, target_text)
        self.val_wer(pred_text, target_text)
        self.log("val/loss", loss, on_step=False, on_epoch=True)
        self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
        self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
        self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)

    def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
        """Test step.

        Logs the teacher-forced loss as well as accuracy, CER and WER of the
        greedy predictions, all aggregated over the epoch.
        """
        data, targets = batch

        # Teacher-forced loss, computed the same way as in training.
        logits = self.teacher_forward(data, targets[:, :-1])
        loss = self.loss_fn(logits, targets[:, 1:])

        # Metrics are evaluated on autoregressive predictions.
        preds = self.predict(data)
        pred_text, target_text = self._get_text(preds), self._get_text(targets)

        # `test_acc` is not defined in this class; presumably provided by LitBase.
        self.test_acc(preds, targets)
        self.test_cer(pred_text, target_text)
        self.test_wer(pred_text, target_text)
        self.log("test/loss", loss, on_step=False, on_epoch=True)
        self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
        self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
        self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)

    def _get_text(
        self,
        xs: Tensor,
    ) -> Sequence[str]:
        """Decodes each row of token indices in ``xs`` with the tokenizer."""
        return [self.tokenizer.decode(x) for x in xs]

    @torch.no_grad()
    def predict(self, x: Tensor) -> Tensor:
        """Predicts text in image.

        Greedy autoregressive decoding: the image is encoded once, then the
        decoder is called repeatedly on the tokens generated so far and the
        most likely next token is appended. Every position after the first
        end (or pad) token of a row is set to the pad token.

        Args:
            x (Tensor): Image(s) to extract text from.

        Shapes:
            - x: :math: `(B, H, W)`
            - output: :math: `(B, S)`

        Returns:
            Tensor: A tensor of token indices of the predictions from the model.
        """
        start_index = self.tokenizer.start_index
        end_index = self.tokenizer.end_index
        pad_index = self.tokenizer.pad_index
        bsz = x.shape[0]

        # Encode image(s) to latent vectors.
        img_features = self.network.encode(x)

        # Output buffer, allocated directly on the input's device and
        # pre-filled with padding so that everything left unwritten after an
        # early stop is already padded.
        indices = torch.full(
            (bsz, self.max_output_len),
            pad_index,
            dtype=torch.long,
            device=x.device,
        )
        indices[:, 0] = start_index

        # Rows that have already produced an end or pad token.
        finished = (indices[:, 0] == end_index) | (indices[:, 0] == pad_index)

        for step in range(1, self.max_output_len):
            tokens = indices[:, :step]  # (B, step)
            logits = self.network.decode(tokens, img_features)  # (B, C, step)
            # Only the prediction for the last position is new.
            next_tokens = torch.argmax(logits, dim=1)[:, -1]  # (B,)
            # Finished rows only emit padding. Without this they would keep
            # producing arbitrary tokens and the stop condition below would
            # almost never hold for batches with texts of different lengths.
            next_tokens = next_tokens.masked_fill(finished, pad_index)
            indices[:, step] = next_tokens

            finished = (
                finished | (next_tokens == end_index) | (next_tokens == pad_index)
            )
            # Early stopping once every row has ended.
            if finished.all():
                break

        return indices
|