diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 00:34:39 +0200 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2022-09-30 00:34:39 +0200 |
commit | 9e85e0883f2e921ca9a57cb2fd93ec47a2535d59 (patch) | |
tree | 3f256120f137fd8bf2df3ec33727a2cdf4992455 /text_recognizer | |
parent | d73c52e15b519af764a83378d4eab19fb31985e0 (diff) |
Update lit transformer
Diffstat (limited to 'text_recognizer')
-rw-r--r-- | text_recognizer/models/transformer.py | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py index f6f10a7..3c38ced 100644 --- a/text_recognizer/models/transformer.py +++ b/text_recognizer/models/transformer.py @@ -40,10 +40,14 @@ class LitTransformer(LitBase): """Forward pass with the transformer network.""" return self.predict(data) + def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor: + """Non-autoregressive forward pass.""" + return self.network(data, targets) + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: """Training step.""" data, targets = batch - logits = self.network(data, targets[:, :-1]) + logits = self.teacher_forward(data, targets[:, :-1]) loss = self.loss_fn(logits, targets[:, 1:]) self.log("train/loss", loss) return loss @@ -51,11 +55,16 @@ class LitTransformer(LitBase): def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: """Validation step.""" data, targets = batch + + logits = self.teacher_forward(data, targets[:, :-1]) + loss = self.loss_fn(logits, targets[:, 1:]) preds = self.predict(data) pred_text, target_text = self._get_text(preds), self._get_text(targets) + self.val_acc(preds, targets) self.val_cer(pred_text, target_text) self.val_wer(pred_text, target_text) + self.log("val/loss", loss, on_step=False, on_epoch=True) self.log("val/acc", self.val_acc, on_step=False, on_epoch=True) self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True) @@ -64,12 +73,15 @@ class LitTransformer(LitBase): """Test step.""" data, targets = batch - # Compute the text prediction. + logits = self.teacher_forward(data, targets[:, :-1]) + loss = self.loss_fn(logits, targets[:, 1:]) preds = self(data) pred_text, target_text = self._get_text(preds), self._get_text(targets) + self.test_acc(preds, targets) self.test_cer(pred_text, target_text) self.test_wer(pred_text, target_text) + self.log("test/loss", loss, on_step=False, on_epoch=True) self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True) @@ -103,24 +115,26 @@ class LitTransformer(LitBase): z = self.network.encode(x) # Create a placeholder matrix for storing outputs from the network - output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) - output[:, 0] = start_index + indecies = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + indecies[:, 0] = start_index for Sy in range(1, self.max_output_len): - context = output[:, :Sy] # (B, Sy) + context = indecies[:, :Sy] # (B, Sy) logits = self.network.decode(z, context) # (B, C, Sy) - tokens = torch.argmax(logits, dim=1) # (B, Sy) - output[:, Sy : Sy + 1] = tokens[:, -1:] + indecies_ = torch.argmax(logits, dim=1) # (B, Sy) + indecies[:, Sy : Sy + 1] = indecies_[:, -1:] # Early stopping of prediction loop if token is end or padding token. if ( - (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index) + (indecies[:, Sy - 1] == end_index) | (indecies[:, Sy - 1] == pad_index) ).all(): break # Set all tokens after end token to pad token. for Sy in range(1, self.max_output_len): - idx = (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index) - output[idx, Sy] = pad_index + idx = (indecies[:, Sy - 1] == end_index) | ( + indecies[:, Sy - 1] == pad_index + ) + indecies[idx, Sy] = pad_index - return output + return indecies |