path: root/text_recognizer/networks/vq_transformer.py
"""Vector quantized encoder, transformer decoder."""
from pathlib import Path
from typing import Tuple, Optional

import torch
from torch import Tensor

from text_recognizer.networks.vqvae.vqvae import VQVAE
from text_recognizer.networks.conv_transformer import ConvTransformer
from text_recognizer.networks.transformer.layers import Decoder


class VqTransformer(ConvTransformer):
    """Convolutional encoder and transformer decoder network."""

    def __init__(
        self,
        input_dims: Tuple[int, int, int],
        encoder_dim: int,
        hidden_dim: int,
        dropout_rate: float,
        num_classes: int,
        pad_index: Tensor,
        encoder: VQVAE,
        decoder: Decoder,
        no_grad: bool,
        pretrained_encoder_path: Optional[str] = None,
    ) -> None:
        super().__init__(
            input_dims=input_dims,
            encoder_dim=encoder_dim,
            hidden_dim=hidden_dim,
            dropout_rate=dropout_rate,
            num_classes=num_classes,
            pad_index=pad_index,
            encoder=encoder,
            decoder=decoder,
        )
        # For typing
        self.encoder: VQVAE

        self.no_grad = no_grad

        if pretrained_encoder_path is not None:
            self.pretrained_encoder_path = (
                Path(__file__).resolve().parents[2] / pretrained_encoder_path
            )
            self._setup_encoder()
        else:
            self.pretrained_encoder_path = None

    def _load_pretrained_encoder(self) -> None:
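        # The checkpoint is expected to be a training checkpoint that stores
        # the VQ-VAE weights under the ["state_dict"]["network"] keys.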
        self.encoder.load_state_dict(
            torch.load(self.pretrained_encoder_path)["state_dict"]["network"]
        )

    def _setup_encoder(self) -> None:
        """Remove unecessary layers."""
        self._load_pretrained_encoder()
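        # The transformer replaces the VQ-VAE decoder, so it can be dropped to
        # free memory; only the encoder and the codebook are used here.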
        del self.encoder.decoder
        # del self.encoder.post_codebook_conv

    def _encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
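        # Encode to continuous latents, then quantize them against the
        # codebook; the commitment loss ties the encoder output to the
        # selected codebook vectors.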
        z_e = self.encoder.encode(x)
        z_q, commitment_loss = self.encoder.quantize(z_e)
        return z_q, commitment_loss

    def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        """Encodes an image into a discrete (VQ) latent representation.

        Args:
            x (Tensor): Image tensor.

        Shape:
            - x: :math: `(B, C, H, W)`
            - z: :math: `(B, Sx, E)`

            where Sx is the length of the flattened feature maps projected from
            the encoder and E is the latent dimension of each position in the
            projected feature maps.

        Returns:
            Tuple[Tensor, Tensor]: The latent embedding of the image and the
                commitment loss from the vector quantization.
        """
        if self.no_grad:
            with torch.no_grad():
                z_q, commitment_loss = self._encode(x)
        else:
            z_q, commitment_loss = self._encode(x)

        z = self.latent_encoder(z_q)

        # Permute tensor from [B, E, Ho * Wo] to [B, Sx, E]
        z = z.permute(0, 2, 1)
        return z, commitment_loss

    def forward(self, x: Tensor, context: Tensor) -> Tuple[Tensor, Tensor]:
        """Encodes images into word piece logtis.

        Args:
            x (Tensor): Input image(s).
            context (Tensor): Target word embeddings.

        Shapes:
            - x: :math: `(B, C, H, W)`
            - context: :math: `(B, Sy, T)`

            where B is the batch size, C is the number of input channels, H is
            the image height, W is the image width, and Sy is the length of the
            target sequence.

        Returns:
            Tuple[Tensor, Tensor]: Sequence of logits and the commitment loss.
        """
        z, commitment_loss = self.encode(x)
        logits = self.decode(z, context)
        return logits, commitment_loss
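

# A minimal usage sketch (illustrative only). The VQVAE and Decoder instances,
# their construction, the checkpoint path, and all concrete dimensions below
# are assumptions rather than values taken from this file; in practice they
# would come from the project's configuration.
#
#     vqvae: VQVAE = ...      # pretrained vector-quantized encoder
#     decoder: Decoder = ...  # autoregressive transformer decoder
#     net = VqTransformer(
#         input_dims=(1, 576, 640),
#         encoder_dim=128,
#         hidden_dim=256,
#         dropout_rate=0.2,
#         num_classes=58,
#         pad_index=torch.tensor(3),
#         encoder=vqvae,
#         decoder=decoder,
#         no_grad=True,
#         pretrained_encoder_path="training/logs/vqvae.pt",  # hypothetical path
#     )
#     x = torch.randn(4, 1, 576, 640)  # batch of input images, (B, C, H, W)
#     context = ...                    # target context, see forward() docstring
#     logits, commitment_loss = net(x, context)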