From 1a193f09bd9199609c30a005dbd28f587dce2606 Mon Sep 17 00:00:00 2001 From: Gustaf Rydholm Date: Sat, 18 Sep 2021 17:43:44 +0200 Subject: Add Vq transfomer model --- text_recognizer/models/vq_transformer.py | 99 ++++++++++++++++++++++++++++++ text_recognizer/networks/vq_transformer.py | 67 ++++++++++++++++---- 2 files changed, 154 insertions(+), 12 deletions(-) create mode 100644 text_recognizer/models/vq_transformer.py (limited to 'text_recognizer') diff --git a/text_recognizer/models/vq_transformer.py b/text_recognizer/models/vq_transformer.py new file mode 100644 index 0000000..71ca2ef --- /dev/null +++ b/text_recognizer/models/vq_transformer.py @@ -0,0 +1,99 @@ +"""PyTorch Lightning model for base Transformers.""" +from typing import Tuple, Type, Set + +import attr +import torch +from torch import Tensor + +from text_recognizer.models.metrics import CharacterErrorRate +from text_recognizer.models.transformer import TransformerLitModel + + +@attr.s(auto_attribs=True, eq=False) +class VqTransformerLitModel(TransformerLitModel): + """A PyTorch Lightning model for transformer networks.""" + + def forward(self, data: Tensor) -> Tensor: + """Forward pass with the transformer network.""" + return self.predict(data) + + def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Tensor: + """Training step.""" + data, targets = batch + logits, commitment_loss = self.network(data, targets[:-1]) + loss = self.loss_fn(logits, targets[1:]) + commitment_loss + self.log("train/loss", loss) + self.log("train/commitment_loss", commitment_loss) + return loss + + def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Validation step.""" + data, targets = batch + + # Compute the loss. + logits, commitment_loss = self.network(data, targets[:-1]) + loss = self.loss_fn(logits, targets[1:]) + commitment_loss + self.log("val/loss", loss, prog_bar=True) + self.log("val/commitment_loss", commitment_loss) + + # Get the token prediction. + pred = self(data) + self.val_cer(pred, targets) + self.log("val/cer", self.val_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(pred, targets) + self.log("val/acc", self.test_acc, on_step=False, on_epoch=True) + + def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> None: + """Test step.""" + data, targets = batch + + # Compute the text prediction. + pred = self(data) + self.test_cer(pred, targets) + self.log("test/cer", self.test_cer, on_step=False, on_epoch=True, prog_bar=True) + self.test_acc(pred, targets) + self.log("test/acc", self.test_acc, on_step=False, on_epoch=True) + + def predict(self, x: Tensor) -> Tensor: + """Predicts text in image. + + Args: + x (Tensor): Image(s) to extract text from. + + Shapes: + - x: :math: `(B, H, W)` + - output: :math: `(B, S)` + + Returns: + Tensor: A tensor of token indices of the predictions from the model. + """ + bsz = x.shape[0] + + # Encode image(s) to latent vectors. + z, _ = self.network.encode(x) + + # Create a placeholder matrix for storing outputs from the network + output = torch.ones((bsz, self.max_output_len), dtype=torch.long).to(x.device) + output[:, 0] = self.start_index + + for Sy in range(1, self.max_output_len): + context = output[:, :Sy] # (B, Sy) + logits = self.network.decode(z, context) # (B, Sy, C) + tokens = torch.argmax(logits, dim=-1) # (B, Sy) + output[:, Sy : Sy + 1] = tokens[:, -1:] + + # Early stopping of prediction loop if token is end or padding token. + if ( + (output[:, Sy - 1] == self.end_index) + | (output[:, Sy - 1] == self.pad_index) + ).all(): + break + + # Set all tokens after end token to pad token. + for Sy in range(1, self.max_output_len): + idx = (output[:, Sy - 1] == self.end_index) | ( + output[:, Sy - 1] == self.pad_index + ) + output[idx, Sy] = self.pad_index + + return output diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py index 0433863..69f68fd 100644 --- a/text_recognizer/networks/vq_transformer.py +++ b/text_recognizer/networks/vq_transformer.py @@ -1,5 +1,6 @@ """Vector quantized encoder, transformer decoder.""" -from typing import Tuple +from pathlib import Path +from typing import Tuple, Optional import torch from torch import Tensor @@ -22,7 +23,8 @@ class VqTransformer(ConvTransformer): pad_index: Tensor, encoder: VQVAE, decoder: Decoder, - pretrained_encoder_path: str, + no_grad: bool, + pretrained_encoder_path: Optional[str] = None, ) -> None: super().__init__( input_dims=input_dims, @@ -34,18 +36,36 @@ class VqTransformer(ConvTransformer): encoder=encoder, decoder=decoder, ) - self.pretrained_encoder_path = pretrained_encoder_path - # For typing self.encoder: VQVAE - def setup_encoder(self) -> None: + self.no_grad = no_grad + + if pretrained_encoder_path is not None: + self.pretrained_encoder_path = ( + Path(__file__).resolve().parents[2] / pretrained_encoder_path + ) + self._setup_encoder() + else: + self.pretrained_encoder_path = None + + def _load_pretrained_encoder(self) -> None: + self.encoder.load_state_dict( + torch.load(self.pretrained_encoder_path)["state_dict"]["network"] + ) + + def _setup_encoder(self) -> None: """Remove unecessary layers.""" - # TODO: load pretrained vqvae + self._load_pretrained_encoder() del self.encoder.decoder - del self.encoder.post_codebook_conv + # del self.encoder.post_codebook_conv + + def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + z_e = self.encoder.encode(x) + z_q, commitment_loss = self.encoder.quantize(z_e) + return z_q, commitment_loss - def encode(self, x: Tensor) -> Tensor: + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: """Encodes an image into a discrete (VQ) latent representation. Args: @@ -62,12 +82,35 @@ class VqTransformer(ConvTransformer): Returns: Tensor: A Latent embedding of the image. """ - with torch.no_grad(): - z_e = self.encoder.encode(x) - z_q, _ = self.encoder.quantize(z_e) + if self.no_grad: + with torch.no_grad(): + z_q, commitment_loss = self._encode(x) + else: + z_q, commitment_loss = self._encode(x) z = self.latent_encoder(z_q) # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E] z = z.permute(0, 2, 1) - return z + return z, commitment_loss + + def forward(self, x: Tensor, context: Tensor) -> Tensor: + """Encodes images into word piece logtis. + + Args: + x (Tensor): Input image(s). + context (Tensor): Target word embeddings. + + Shapes: + - x: :math: `(B, C, H, W)` + - context: :math: `(B, Sy, T)` + + where B is the batch size, C is the number of input channels, H is + the image height and W is the image width. + + Returns: + Tensor: Sequence of logits. + """ + z, commitment_loss = self.encode(x) + logits = self.decode(z, context) + return logits, commitment_loss -- cgit v1.2.3-70-g09d2