summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-17 22:44:10 +0100
committerGustaf Rydholm <gustaf.rydholm@gmail.com>2021-11-17 22:44:10 +0100
commita9a7c8ab3314e36eed7781552bf1b383a624eea2 (patch)
tree89cd4cb6abf8370150a94d9fb1a8ca2faa7345b8
parent8b06e0ff3185436848c08bf04d730d7e5212e0e5 (diff)
Update VqTransformer
-rw-r--r--text_recognizer/networks/vq_transformer.py69
1 files changed, 16 insertions, 53 deletions
diff --git a/text_recognizer/networks/vq_transformer.py b/text_recognizer/networks/vq_transformer.py
index 1c85bf9..a2bd81b 100644
--- a/text_recognizer/networks/vq_transformer.py
+++ b/text_recognizer/networks/vq_transformer.py
@@ -1,14 +1,10 @@
"""Vector quantized encoder, transformer decoder."""
-from pathlib import Path
-from typing import OrderedDict, Tuple
+from typing import Optional, Tuple, Type
-from omegaconf import OmegaConf
-from hydra.utils import instantiate
-import torch
-from torch import Tensor
+from torch import nn, Tensor
-from text_recognizer.networks.vqvae.vqvae import VQVAE
from text_recognizer.networks.conv_transformer import ConvTransformer
+from text_recognizer.networks.quantizer.quantizer import VectorQuantizer
from text_recognizer.networks.transformer.layers import Decoder
@@ -18,57 +14,26 @@ class VqTransformer(ConvTransformer):
def __init__(
self,
input_dims: Tuple[int, int, int],
- encoder_dim: int,
hidden_dim: int,
- dropout_rate: float,
num_classes: int,
pad_index: Tensor,
+ encoder: nn.Module,
decoder: Decoder,
- no_grad: bool,
- pretrained_encoder_path: str,
+ pixel_pos_embedding: Type[nn.Module],
+ quantizer: VectorQuantizer,
+ token_pos_embedding: Optional[Type[nn.Module]] = None,
) -> None:
- # For typing
- self.encoder: VQVAE = None
- self.no_grad = no_grad
-
super().__init__(
input_dims=input_dims,
- encoder_dim=encoder_dim,
hidden_dim=hidden_dim,
- dropout_rate=dropout_rate,
num_classes=num_classes,
pad_index=pad_index,
- encoder=self.encoder,
+ encoder=encoder,
decoder=decoder,
+ pixel_pos_embedding=pixel_pos_embedding,
+ token_pos_embedding=token_pos_embedding,
)
- self._setup_encoder(pretrained_encoder_path)
-
- def _load_state_dict(self, path: Path) -> OrderedDict:
- weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0]
- renamed_state_dict = OrderedDict()
- state_dict = torch.load(weights_path)["state_dict"]
- for key in state_dict.keys():
- if "network" in key:
- new_key = key.removeprefix("network.")
- renamed_state_dict[new_key] = state_dict[key]
- del state_dict
- return renamed_state_dict
-
- def _setup_encoder(self, pretrained_encoder_path: str,) -> None:
- """Load encoder module."""
- path = Path(__file__).resolve().parents[2] / pretrained_encoder_path
- with open(path / "config.yaml") as f:
- cfg = OmegaConf.load(f)
- state_dict = self._load_state_dict(path)
- self.encoder = instantiate(cfg.network)
- self.encoder.load_state_dict(state_dict)
- del self.encoder.decoder
-
- def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
- z_e = self.encoder.encode(x)
- z_q, commitment_loss = self.encoder.quantize(z_e)
- z = self.encoder.post_codebook_conv(z_q)
- return z, commitment_loss
+ self.quantizer = quantizer
def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
"""Encodes an image into a discrete (VQ) latent representation.
@@ -87,13 +52,11 @@ class VqTransformer(ConvTransformer):
Returns:
Tensor: A Latent embedding of the image.
"""
- if self.no_grad:
- with torch.no_grad():
- z_q, commitment_loss = self._encode(x)
- else:
- z_q, commitment_loss = self._encode(x)
-
- z = self.latent_encoder(z_q)
+ z = self.encoder(x)
+ z = self.conv(z)
+ z, _, commitment_loss = self.quantizer(z)
+ z = self.pixel_pos_embedding(z)
+ z = z.flatten(start_dim=2)
# Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
z = z.permute(0, 2, 1)