"""Vector quantized encoder, transformer decoder."""
from pathlib import Path
from typing import OrderedDict, Tuple, Union
from omegaconf import OmegaConf
from hydra.utils import instantiate
import torch
from torch import Tensor
from text_recognizer.networks.vqvae.vqvae import VQVAE
from text_recognizer.networks.conv_transformer import ConvTransformer
from text_recognizer.networks.transformer.layers import Decoder


class VqTransformer(ConvTransformer):
    """Convolutional encoder and transformer decoder network."""

    def __init__(
        self,
        input_dims: Tuple[int, int, int],
        encoder_dim: int,
        hidden_dim: int,
        dropout_rate: float,
        num_classes: int,
        pad_index: Tensor,
        decoder: Decoder,
        no_grad: bool,
        pretrained_encoder_path: str,
    ) -> None:
        # Placeholder for typing; the pretrained encoder is loaded in _setup_encoder.
        self.encoder: VQVAE = None
        self.no_grad = no_grad
        super().__init__(
            input_dims=input_dims,
            encoder_dim=encoder_dim,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            num_classes=num_classes,
            pad_index=pad_index,
            encoder=self.encoder,
            decoder=decoder,
        )
        self._setup_encoder(pretrained_encoder_path)

    def _load_state_dict(self, path: Path) -> OrderedDict:
        """Load pretrained VQVAE weights and rename the checkpoint keys."""
        weights_path = list((path / "checkpoints").glob("epoch=*.ckpt"))[0]
        renamed_state_dict = OrderedDict()
        state_dict = torch.load(weights_path)["state_dict"]
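        # Checkpoint weights are stored under a "network." prefix; strip it so
        # the keys match the bare VQVAE module's state dict.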
        for key in state_dict.keys():
            if "network" in key:
                new_key = key.removeprefix("network.")
                renamed_state_dict[new_key] = state_dict[key]
        del state_dict
        return renamed_state_dict

    def _setup_encoder(self, pretrained_encoder_path: str) -> None:
        """Load the pretrained encoder module from its config and checkpoint."""
        path = Path(__file__).resolve().parents[2] / pretrained_encoder_path
        with open(path / "config.yaml") as f:
            cfg = OmegaConf.load(f)
        state_dict = self._load_state_dict(path)
        self.encoder = instantiate(cfg.network)
        self.encoder.load_state_dict(state_dict)
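        # Only the encoder, quantizer, and post-codebook convolution are used
        # here, so the VQVAE decoder can be dropped to free memory.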
        del self.encoder.decoder

    def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
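        # Encode to continuous latents, quantize them against the codebook
        # (which also yields the commitment loss), then apply the
        # post-codebook convolution.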
        z_e = self.encoder.encode(x)
        z_q, commitment_loss = self.encoder.quantize(z_e)
        z = self.encoder.post_codebook_conv(z_q)
        return z, commitment_loss

    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Encodes an image into a discrete (VQ) latent representation.

        Args:
            x (Tensor): Image tensor.

        Shape:
            - x: :math:`(B, C, H, W)`
            - z: :math:`(B, Sx, E)`

            where Sx is the length of the flattened feature maps projected from
            the encoder and E is the latent dimension of each pixel in the
            projected feature maps.

        Returns:
            Tuple[Tensor, Tensor]: The latent embedding of the image and the
                commitment loss from the vector quantizer.
        """
        if self.no_grad:
            with torch.no_grad():
                z_q, commitment_loss = self._encode(x)
        else:
            z_q, commitment_loss = self._encode(x)
        z = self.latent_encoder(z_q)
        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
        z = z.permute(0, 2, 1)
        return z, commitment_loss

    def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]:
        """Encodes images into word piece logits.

        Args:
            x (Tensor): Input image(s).
            context (Tensor): Target word embeddings.

        Shapes:
            - x: :math:`(B, C, H, W)`
            - context: :math:`(B, Sy, T)`

            where B is the batch size, C is the number of input channels, H is
            the image height and W is the image width.

        Returns:
            Tuple[Tensor, Tensor]: Sequence of logits and the commitment loss.
        """
        z, commitment_loss = self.encode(x)
        logits = self.decode(z, context)
        return logits, commitment_loss