summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vq_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'text_recognizer/networks/vq_transformer.py')
-rw-r--r--text_recognizer/networks/vq_transformer.py67
1 files changed, 55 insertions, 12 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index 0433863..69f68fd 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -1,5 +1,6 @@
"""Vector quantized encoder, transformer decoder."""
-from typing import Tuple
+from pathlib import Path
+from typing import Tuple, Optional
import torch
from torch import Tensor
@@ -22,7 +23,8 @@ class VqTransformer(ConvTransformer):
pad_index: Tensor,
encoder: VQVAE,
decoder: Decoder,
- pretrained_encoder_path: str,
+ no_grad: bool,
+ pretrained_encoder_path: Optional[str] = None,
) -> None:
super().__init__(
input_dims=input_dims,
@@ -34,18 +36,36 @@ class VqTransformer(ConvTransformer):
encoder=encoder,
decoder=decoder,
)
- self.pretrained_encoder_path = pretrained_encoder_path
-
# For typing
self.encoder: VQVAE
- def setup_encoder(self) -> None:
+ self.no_grad = no_grad
+
+ if pretrained_encoder_path is not None:
+ self.pretrained_encoder_path = (
+ Path(__file__).resolve().parents[2] / pretrained_encoder_path
+ )
+ self._setup_encoder()
+ else:
+ self.pretrained_encoder_path = None
+
+ def _load_pretrained_encoder(self) -> None:
+ self.encoder.load_state_dict(
+ torch.load(self.pretrained_encoder_path)["state_dict"]["network"]
+ )
+
+ def _setup_encoder(self) -> None:
"""Remove unecessary layers."""
- # TODO: load pretrained vqvae
+ self._load_pretrained_encoder()
del self.encoder.decoder
- del self.encoder.post_codebook_conv
+ # del self.encoder.post_codebook_conv
+
+ def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ z_e = self.encoder.encode(x)
+ z_q, commitment_loss = self.encoder.quantize(z_e)
+ return z_q, commitment_loss
- def encode(self, x: Tensor) -> Tensor:
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes an image into a discrete (VQ) latent representation.
Args:
@@ -62,12 +82,35 @@ class VqTransformer(ConvTransformer):
Returns:
Tensor: A Latent embedding of the image.
"""
- with torch.no_grad():
- z_e = self.encoder.encode(x)
- z_q, _ = self.encoder.quantize(z_e)
+ if self.no_grad:
+ with torch.no_grad():
+ z_q, commitment_loss = self._encode(x)
+ else:
+ z_q, commitment_loss = self._encode(x)
z = self.latent_encoder(z_q)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
z = z.permute(0, 2, 1)
- return z
+ return z, commitment_loss
+
+ def forward(self, x: Tensor, context: Tensor) -> Tensor:
+ """Encodes images into word piece logtis.
+
+ Args:
+ x (Tensor): Input image(s).
+ context (Tensor): Target word embeddings.
+
+ Shapes:
+ - x: :math: `(B, C, H, W)`
+ - context: :math: `(B, Sy, T)`
+
+ where B is the batch size, C is the number of input channels, H is
+ the image height and W is the image width.
+
+ Returns:
+ Tensor: Sequence of logits.
+ """
+ z, commitment_loss = self.encode(x)
+ logits = self.decode(z, context)
+ return logits, commitment_loss