Diffstat (limited to 'text_recognizer')
-rw-r--r--  text_recognizer/models/transformer.py | 36 +++++++++++++++++++++++++-----------
1 file changed, 25 insertions(+), 11 deletions(-)
diff --git a/text_recognizer/models/transformer.py b/text_recognizer/models/transformer.py
index f6f10a7..3c38ced 100644
--- a/text_recognizer/models/transformer.py
+++ b/text_recognizer/models/transformer.py
@@ -40,10 +40,14 @@ class LitTransformer(LitBase):
         """Forward pass with the transformer network."""
         return self.predict(data)

+    def teacher_forward(self, data: Tensor, targets: Tensor) -> Tensor:
+        """Non-autoregressive forward pass with teacher forcing."""
+        return self.network(data, targets)
+
     def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor:
         """Training step."""
         data, targets = batch
-        logits = self.network(data, targets[:, :-1])
+        logits = self.teacher_forward(data, targets[:, :-1])
         loss = self.loss_fn(logits, targets[:, 1:])
         self.log("train/loss", loss)
         return loss
@@ -51,11 +55,16 @@ class LitTransformer(LitBase):
     def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None:
         """Validation step."""
         data, targets = batch
+
+        logits = self.teacher_forward(data, targets[:, :-1])
+        loss = self.loss_fn(logits, targets[:, 1:])
         preds = self.predict(data)
         pred_text, target_text = self._get_text(preds), self._get_text(targets)
+
         self.val_acc(preds, targets)
         self.val_cer(pred_text, target_text)
         self.val_wer(pred_text, target_text)
+        self.log("val/loss", loss, on_step=False, on_epoch=True)
         self.log("val/acc", self.val_acc, on_step=False, on_epoch=True)
         self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("val/wer", self.val_wer, on_step=False, on_epoch=True, prog_bar=True)
@@ -64,12 +73,15 @@ class LitTransformer(LitBase):
         """Test step."""
         data, targets = batch

-        # Compute the text prediction.
+        logits = self.teacher_forward(data, targets[:, :-1])
+        loss = self.loss_fn(logits, targets[:, 1:])
         preds = self(data)
         pred_text, target_text = self._get_text(preds), self._get_text(targets)
+
         self.test_acc(preds, targets)
         self.test_cer(pred_text, target_text)
         self.test_wer(pred_text, target_text)
+        self.log("test/loss", loss, on_step=False, on_epoch=True)
         self.log("test/acc", self.test_acc, on_step=False, on_epoch=True)
         self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True)
         self.log("test/wer", self.test_wer, on_step=False, on_epoch=True, prog_bar=True)
@@ -103,24 +115,26 @@ class LitTransformer(LitBase):
         z = self.network.encode(x)

         # Create a placeholder matrix for storing outputs from the network
-        output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
-        output[:, 0] = start_index
+        indices = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device)
+        indices[:, 0] = start_index

         for Sy in range(1, self.max_output_len):
-            context = output[:, :Sy]  # (B, Sy)
+            context = indices[:, :Sy]  # (B, Sy)
             logits = self.network.decode(z, context)  # (B, C, Sy)
-            tokens = torch.argmax(logits, dim=1)  # (B, Sy)
-            output[:, Sy : Sy + 1] = tokens[:, -1:]
+            indices_ = torch.argmax(logits, dim=1)  # (B, Sy)
+            indices[:, Sy : Sy + 1] = indices_[:, -1:]

             # Early stopping of prediction loop if token is end or padding token.
             if (
-                (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
+                (indices[:, Sy - 1] == end_index) | (indices[:, Sy - 1] == pad_index)
             ).all():
                 break

         # Set all tokens after end token to pad token.
         for Sy in range(1, self.max_output_len):
-            idx = (output[:, Sy - 1] == end_index) | (output[:, Sy - 1] == pad_index)
-            output[idx, Sy] = pad_index
+            idx = (indices[:, Sy - 1] == end_index) | (
+                indices[:, Sy - 1] == pad_index
+            )
+            indices[idx, Sy] = pad_index

-        return output
+        return indices
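
For readers tracing the new teacher-forcing path: the decoder consumes the ground-truth sequence shifted right by one, so the logit at position t is scored against the token at position t + 1, and the whole sequence is evaluated in a single non-autoregressive pass. Below is a minimal sketch of that shift, assuming logits of shape (B, C, Sy) and a cross-entropy loss that ignores the pad index; the names are illustrative stand-ins for self.network and self.loss_fn, not taken from the repository.

from torch import Tensor, nn

pad_index = 0  # assumed pad token id; the real value comes from the character mapping
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_index)

def teacher_forced_loss(network: nn.Module, data: Tensor, targets: Tensor) -> Tensor:
    """Score the whole target sequence in one pass (teacher forcing)."""
    # Decoder input drops the last token; labels drop the first (start) token,
    # so the logit at position t is trained to predict targets[:, t + 1].
    logits = network(data, targets[:, :-1])  # (B, C, Sy - 1)
    return loss_fn(logits, targets[:, 1:])   # labels: (B, Sy - 1)

PyTorch's CrossEntropyLoss accepts the (B, C, S) layout directly, which is why the diff can pass the decoder logits to self.loss_fn without a reshape.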
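
The second loop in predict pads out everything generated after the first end token; since a freshly padded position triggers the condition on the next iteration, the pad value propagates to the end of the row. One vectorized alternative, sketched under the assumption that indices, end_index, and pad_index are as in the diff above (untested against this repository, shown only to illustrate the masking idea):

import torch

def pad_after_end(indices: torch.Tensor, end_index: int, pad_index: int) -> torch.Tensor:
    """Replace every token strictly after the first end/pad token with pad."""
    done = (indices == end_index) | (indices == pad_index)  # (B, S) bool
    seen = done.cumsum(dim=1) > 0   # True at and after the first end/pad token
    after = torch.zeros_like(seen)
    after[:, 1:] = seen[:, :-1]     # shift right: True strictly after it
    return indices.masked_fill(after, pad_index)

This keeps the end token itself, matches the loop's propagation behavior, and removes the Python-level iteration over sequence positions.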