summaryrefslogtreecommitdiff
path: root/text_recognizer/networks/vq_transformer.py
blob: 2121c3363f38da493df42bc0bb37a1e67e1c7112 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
"""Vector quantized encoder, transformer decoder."""
from pathlib import Path
from typing import OrderedDict, Tuple, Union

from omegaconf import OmegaConf
from hydra.utils import instantiate
import torch
from torch import Tensor

from text_recognizer.networks.vqvae.vqvae import VQVAE
from text_recognizer.networks.conv_transformer import ConvTransformer
from text_recognizer.networks.transformer.layers import Decoder


class VqTransformer(ConvTransformer):
    """Vector quantized (VQ-VAE) encoder and transformer decoder network.

    The encoder is not supplied by the caller. It is instantiated from the
    ``config.yaml`` of a pretrained run and initialised with the weights of a
    checkpoint stored in that run's ``checkpoints`` directory.
    """

    def __init__(
        self,
        input_dims: Tuple[int, int, int],
        encoder_dim: int,
        hidden_dim: int,
        dropout_rate: float,
        num_classes: int,
        pad_index: Tensor,
        decoder: Decoder,
        no_grad: bool,
        pretrained_encoder_path: str,
    ) -> None:
        """Builds the network and loads the pretrained encoder.

        Args:
            input_dims (Tuple[int, int, int]): Forwarded to ConvTransformer.
            encoder_dim (int): Forwarded to ConvTransformer.
            hidden_dim (int): Forwarded to ConvTransformer.
            dropout_rate (float): Forwarded to ConvTransformer.
            num_classes (int): Forwarded to ConvTransformer.
            pad_index (Tensor): Forwarded to ConvTransformer.
            decoder (Decoder): Forwarded to ConvTransformer.
            no_grad (bool): If True, the pretrained encoder is run under
                ``torch.no_grad()`` in :meth:`encode`.
            pretrained_encoder_path (str): Directory of the pretrained run,
                joined onto the third parent directory of this file. It must
                contain ``config.yaml`` and ``checkpoints/epoch=*.ckpt``.
        """
        # Placeholder so the attribute exists (and is typed) before the parent
        # constructor runs; the real encoder is attached in _setup_encoder.
        self.encoder: VQVAE = None
        self.no_grad = no_grad

        super().__init__(
            input_dims=input_dims,
            encoder_dim=encoder_dim,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            num_classes=num_classes,
            pad_index=pad_index,
            encoder=self.encoder,
            decoder=decoder,
        )
        self._setup_encoder(pretrained_encoder_path)

    def _load_state_dict(self, path: Path) -> OrderedDict:
        """Loads the network weights from a checkpoint of a pretrained run.

        Only entries whose key starts with ``network.`` are kept, and that
        prefix is stripped from the returned keys.

        Args:
            path (Path): Run directory containing a ``checkpoints`` folder.

        Raises:
            FileNotFoundError: If no ``epoch=*.ckpt`` file exists in the
                ``checkpoints`` folder.

        Returns:
            OrderedDict: State dict with the ``network.`` prefix removed.
        """
        checkpoint_dir = path / "checkpoints"
        # glob() yields entries in filesystem order; sort so that the selected
        # checkpoint is reproducible when several files match.
        candidates = sorted(checkpoint_dir.glob("epoch=*.ckpt"))
        if not candidates:
            raise FileNotFoundError(
                f"No checkpoint matching 'epoch=*.ckpt' found in {checkpoint_dir}"
            )
        weights_path = candidates[0]

        # Map to CPU so that a checkpoint saved on a GPU can be restored on any
        # host; load_state_dict copies the values into the module's own tensors.
        # NOTE(review): the checkpoint is unpickled by torch.load; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(weights_path, map_location="cpu")["state_dict"]

        prefix = "network."
        renamed_state_dict = OrderedDict()
        for key, value in state_dict.items():
            # Prefix match (not substring match): keys that merely contain
            # "network" somewhere else do not belong to the wrapped network.
            if key.startswith(prefix):
                renamed_state_dict[key.removeprefix(prefix)] = value
        return renamed_state_dict

    def _setup_encoder(self, pretrained_encoder_path: str) -> None:
        """Instantiates the pretrained encoder and loads its weights.

        Args:
            pretrained_encoder_path (str): Directory of the pretrained run,
                joined onto the third parent directory of this file.
        """
        path = Path(__file__).resolve().parents[2] / pretrained_encoder_path
        with open(path / "config.yaml", encoding="utf-8") as f:
            cfg = OmegaConf.load(f)
        state_dict = self._load_state_dict(path)
        self.encoder = instantiate(cfg.network)
        self.encoder.load_state_dict(state_dict)
        # The decoder half of the pretrained network is not used by this class.
        del self.encoder.decoder

    def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Runs encode, quantize and the post codebook convolution.

        Args:
            x (Tensor): Image tensor.

        Returns:
            Tuple[Tensor, Tensor]: Output of the post codebook convolution and
                the commitment loss returned by the quantizer.
        """
        z_e = self.encoder.encode(x)
        z_q, commitment_loss = self.encoder.quantize(z_e)
        z = self.encoder.post_codebook_conv(z_q)
        return z, commitment_loss

    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Encodes an image into a discrete (VQ) latent representation.

        Args:
            x (Tensor): Image tensor.

        Shape:
            - x: :math: `(B, C, H, W)`
            - z: :math: `(B, Sx, E)`

            where Sx is the length of the flattened feature maps projected from
            the encoder. E latent dimension for each pixel in the projected
            feature maps.

        Returns:
            Tuple[Tensor, Tensor]: A latent embedding of the image and the
                commitment loss from the vector quantization.
        """
        if self.no_grad:
            # Only the pretrained encoder is excluded from autograd; the latent
            # encoder below always runs with the ambient grad mode.
            with torch.no_grad():
                z_q, commitment_loss = self._encode(x)
        else:
            z_q, commitment_loss = self._encode(x)

        z = self.latent_encoder(z_q)

        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
        z = z.permute(0, 2, 1)
        return z, commitment_loss

    def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]:
        """Encodes images into word piece logits.

        Args:
            x (Tensor): Input image(s).
            context (Tensor): Target word embeddings.

        Shapes:
            - x: :math: `(B, C, H, W)`
            - context: :math: `(B, Sy, T)`

            where B is the batch size, C is the number of input channels, H is
            the image height and W is the image width.

        Returns:
            Tuple[Tensor, Tensor]: Sequence of logits and the commitment loss
                from the vector quantization.
        """
        z, commitment_loss = self.encode(x)
        logits = self.decode(z, context)
        return logits, commitment_loss