summaryrefslogtreecommitdiff
path: root/text_recognizer
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 00:34:39 +0200
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2022-09-30 00:34:39 +0200
commit9e85e0883f2e921ca9a57cb2fd93ec47a2535d59 (patch)
tree3f256120f137fd8bf2df3ec33727a2cdf4992455 /text_recognizer
parentd73c52e15b519af764a83378d4eab19fb31985e0 (diff)
Update lit transformer
Diffstat (limited to 'text_recognizer')
-rw-r--r--text_recognizer/models/transformer.py36
1 files changed, 25 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index f6f10a7..3c38ced 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -40,10 +40,14 @@ class LitTransformer(LitBase):
"""Forward pass with the transformer network."""
return self.predict(data)
+ def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
+ """Non-autoregressive forward pass."""
+ return self.network(data, targets)
+
def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
"""Training step."""
data, targets = batch
- logits = self.network(data, targets[:, :-1])
+ logits = self.teacher_forward(data, targets[:, :-1])
loss = self.loss_fn(logits, targets[:, 1:])
self.log("train/loss", loss)
return loss
@@ -51,11 +55,16 @@ class LitTransformer(LitBase):
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
"""Validation step."""
data, targets = batch
+
+ logits = self.teacher_forward(data, targets[:, :-1])
+ loss = self.loss_fn(logits, targets[:, 1:])
preds = self.predict(data)
pred_text, target_text = self._get_text(preds), self._get_text(targets)
+
self.val_acc(preds, targets)
self.val_cer(pred_text, target_text)
self.val_wer(pred_text, target_text)
+ self.log("val/loss", loss, on_step=False, on_epoch=True)
self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
@@ -64,12 +73,15 @@ class LitTransformer(LitBase):
"""Test step."""
data, targets = batch
- # Compute the text prediction.
+ logits = self.teacher_forward(data, targets[:, :-1])
+ loss = self.loss_fn(logits, targets[:, 1:])
preds = self(data)
pred_text, target_text = self._get_text(preds), self._get_text(targets)
+
self.test_acc(preds, targets)
self.test_cer(pred_text, target_text)
self.test_wer(pred_text, target_text)
+ self.log("test/loss", loss, on_step=False, on_epoch=True)
self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
@@ -103,24 +115,26 @@ class LitTransformer(LitBase):
z = self.network.encode(x)
# Create a placeholder matrix for storing outputs from the network
- output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
- output[:, 0] = start_index
+ indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+ indecies[:, 0] = start_index
for Sy in range(1, self.max_output_len):
- context = output[:, :Sy] # (B, Sy)
+ context = indecies[:, :Sy] # (B, Sy)
logits = self.network.decode(z, context) # (B, C, Sy)
- tokens = torch.argmax(logits, dim=1) # (B, Sy)
- output[:, Sy : Sy + 1] = tokens[:, -1:]
+ indecies_ = torch.argmax(logits, dim=1) # (B, Sy)
+ indecies[:, Sy : Sy + 1] = indecies_[:, -1:]
# Early stopping of prediction loop if token is end or padding token.
if (
- (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
+ (indecies[:, Sy - 1] == end_index) | (indecies[:, Sy - 1] == pad_index)
).all():
break
# Set all tokens after end token to pad token.
for Sy in range(1, self.max_output_len):
- idx = (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
- output[idx, Sy] = pad_index
+ idx = (indecies[:, Sy - 1] == end_index) | (
+ indecies[:, Sy - 1] == pad_index
+ )
+ indecies[idx, Sy] = pad_index
- return output
+ return indecies