Diffstat (limited to 'text_recognizer/networks/vqvae/pixelcnn.py')
-rw-r--r-- | text_recognizer/networks/vqvae/pixelcnn.py | 165 |
1 file changed, 165 insertions, 0 deletions
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
new file mode 100644
index 0000000..5c580df
--- /dev/null
+++ b/text_recognizer/networks/vqvae/pixelcnn.py
@@ -0,0 +1,165 @@
"""PixelCNN encoder and decoder.

Same as in VQGAN, among others. Hopefully, this yields better reconstructions...

TODO: Add the number of residual layers as a parameter.
"""
from typing import Sequence

from torch import nn
from torch import Tensor

from text_recognizer.networks.vqvae.attention import Attention
from text_recognizer.networks.vqvae.norm import Normalize
from text_recognizer.networks.vqvae.residual import Residual
from text_recognizer.networks.vqvae.resize import Downsample, Upsample


class Encoder(nn.Module):
    """PixelCNN encoder."""

    def __init__(
        self,
        in_channels: int,
        hidden_dim: int,
        channels_multipliers: Sequence[int],
        dropout_rate: float,
    ) -> None:
        super().__init__()
        self.in_channels = in_channels
        self.dropout_rate = dropout_rate
        self.hidden_dim = hidden_dim
        self.channels_multipliers = tuple(channels_multipliers)
        self.encoder = self._build_encoder()

    def _build_encoder(self) -> nn.Sequential:
        """Builds the encoder."""
        encoder = [
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.hidden_dim,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
        ]
        num_blocks = len(self.channels_multipliers)
        # Prepend 1 so that block i maps hidden_dim * multiplier[i - 1] to
        # hidden_dim * multiplier[i].
        in_channels_multipliers = (1,) + self.channels_multipliers
        for i in range(num_blocks):
            in_channels = self.hidden_dim * in_channels_multipliers[i]
            out_channels = self.hidden_dim * self.channels_multipliers[i]
            encoder += [
                Residual(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    dropout_rate=self.dropout_rate,
                    use_norm=True,
                ),
            ]
            if i == num_blocks - 1:
                encoder.append(Attention(in_channels=out_channels))
            # Reduce the spatial resolution after every block.
            encoder.append(Downsample())

        # Two residual + attention stages at the bottleneck.
        for _ in range(2):
            encoder += [
                Residual(
                    in_channels=self.hidden_dim * self.channels_multipliers[-1],
                    out_channels=self.hidden_dim * self.channels_multipliers[-1],
                    dropout_rate=self.dropout_rate,
                    use_norm=True,
                ),
                Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]),
            ]

        encoder += [
            Normalize(num_channels=self.hidden_dim * self.channels_multipliers[-1]),
            nn.Mish(),
            nn.Conv2d(
                in_channels=self.hidden_dim * self.channels_multipliers[-1],
                out_channels=self.hidden_dim * self.channels_multipliers[-1],
                kernel_size=3,
                stride=1,
                padding=1,
            ),
        ]
        return nn.Sequential(*encoder)

    def forward(self, x: Tensor) -> Tensor:
        """Encodes input to a latent representation."""
        return self.encoder(x)


class Decoder(nn.Module):
    """PixelCNN decoder."""

    def __init__(
        self,
        hidden_dim: int,
        channels_multipliers: Sequence[int],
        out_channels: int,
        dropout_rate: float,
    ) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        self.out_channels = out_channels
        self.channels_multipliers = tuple(channels_multipliers)
        self.dropout_rate = dropout_rate
        self.decoder = self._build_decoder()

    def _build_decoder(self) -> nn.Sequential:
        """Builds the decoder."""
        in_channels = self.hidden_dim * self.channels_multipliers[0]
        decoder = [
            Residual(
                in_channels=in_channels,
                out_channels=in_channels,
                dropout_rate=self.dropout_rate,
                use_norm=True,
            ),
            Attention(in_channels=in_channels),
            Residual(
                in_channels=in_channels,
                out_channels=in_channels,
                dropout_rate=self.dropout_rate,
                use_norm=True,
            ),
        ]

        # Append 1 so that the last block maps back down to hidden_dim channels.
        out_channels_multipliers = self.channels_multipliers + (1,)
        num_blocks = len(self.channels_multipliers)

        for i in range(num_blocks):
            in_channels = self.hidden_dim * self.channels_multipliers[i]
            out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
            decoder.append(
                Residual(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    dropout_rate=self.dropout_rate,
                    use_norm=True,
                )
            )
            if i == 0:
                decoder.append(Attention(in_channels=out_channels))
            # Increase the spatial resolution after every block.
            decoder.append(Upsample())

        decoder += [
            Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
            nn.Mish(),
            nn.Conv2d(
                in_channels=self.hidden_dim * out_channels_multipliers[-1],
                out_channels=self.out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
            ),
        ]
        return nn.Sequential(*decoder)

    def forward(self, x: Tensor) -> Tensor:
        """Decodes latent vector."""
        return self.decoder(x)
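For context, here is a minimal usage sketch of the two modules. The hyperparameter values are illustrative, not taken from the repository's configs, and the shape comments assume that Downsample halves and Upsample doubles the spatial resolution while Residual, Attention, and Normalize preserve it:

    import torch

    from text_recognizer.networks.vqvae.pixelcnn import Decoder, Encoder

    # Illustrative hyperparameters (assumed, not from the commit).
    encoder = Encoder(
        in_channels=1, hidden_dim=32, channels_multipliers=[1, 2, 4], dropout_rate=0.1
    )
    decoder = Decoder(
        hidden_dim=32, channels_multipliers=[4, 2, 1], out_channels=1, dropout_rate=0.1
    )

    x = torch.randn(1, 1, 64, 64)  # (batch, channels, height, width)
    z = encoder(x)  # Three Downsample stages: 64 -> 8 spatially, 32 * 4 = 128 channels.
    x_hat = decoder(z)  # Three Upsample stages: 8 -> 64 spatially, back to 1 channel.

    print(z.shape)  # Expected: torch.Size([1, 128, 8, 8])
    print(x_hat.shape)  # Expected: torch.Size([1, 1, 64, 64])

Note that the decoder's channels_multipliers is given in the reverse order of the encoder's, since the decoder indexes multiplier 0 as its widest input and steps down from there.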