Diffstat (limited to 'text_recognizer/networks/vqvae/pixelcnn.py')
-rw-r--r-- text_recognizer/networks/vqvae/pixelcnn.py | 165
1 file changed, 165 insertions(+), 0 deletions(-)
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
new file mode 100644
index 0000000..5c580df
--- /dev/null
+++ b/text_recognizer/networks/vqvae/pixelcnn.py
@@ -0,0 +1,165 @@
+"""PixelCNN encoder and decoder.
+
+Same as in VQGAN among other. Hopefully, better reconstructions...
+
+TODO: Add num of residual layers.
+"""
+from typing import Sequence
+
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae.attention import Attention
+from text_recognizer.networks.vqvae.norm import Normalize
+from text_recognizer.networks.vqvae.residual import Residual
+from text_recognizer.networks.vqvae.resize import Downsample, Upsample
+
+
+class Encoder(nn.Module):
+ """PixelCNN encoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+ self.in_channels = in_channels
+ self.dropout_rate = dropout_rate
+ self.hidden_dim = hidden_dim
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.encoder = self._build_encoder()
+
+ def _build_encoder(self) -> nn.Sequential:
+ """Builds encoder."""
+ encoder = [
+ nn.Conv2d(
+ in_channels=self.in_channels,
+ out_channels=self.hidden_dim,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ num_blocks = len(self.channels_multipliers)
+ in_channels_multipliers = (1,) + self.channels_multipliers
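+        # Each block widens the channels by its multiplier, and a Downsample
+        # halves the spatial resolution; attention is added only at the
+        # deepest (last) block.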
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * in_channels_multipliers[i]
+ out_channels = self.hidden_dim * self.channels_multipliers[i]
+ encoder += [
+ Residual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+ if i == num_blocks - 1:
+ encoder.append(Attention(in_channels=out_channels))
+ encoder.append(Downsample())
+
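+        # Bottleneck at the lowest resolution: two residual blocks, each
+        # followed by self-attention.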
+ for _ in range(2):
+ encoder += [
+ Residual(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+                Attention(
+                    in_channels=self.hidden_dim * self.channels_multipliers[-1]
+                ),
+            ]
+
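+        # Output head: normalize, activate, and project with a 3x3 convolution.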
+ encoder += [
+ Normalize(num_channels=self.hidden_dim * self.channels_multipliers[-1]),
+ nn.Mish(),
+ nn.Conv2d(
+ in_channels=self.hidden_dim * self.channels_multipliers[-1],
+ out_channels=self.hidden_dim * self.channels_multipliers[-1],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Encodes input to a latent representation."""
+ return self.encoder(x)
+
+
+class Decoder(nn.Module):
+ """PixelCNN decoder."""
+
+ def __init__(
+ self,
+ hidden_dim: int,
+ channels_multipliers: Sequence[int],
+ out_channels: int,
+ dropout_rate: float,
+ ) -> None:
+ super().__init__()
+ self.hidden_dim = hidden_dim
+ self.out_channels = out_channels
+ self.channels_multipliers = tuple(channels_multipliers)
+ self.dropout_rate = dropout_rate
+ self.decoder = self._build_decoder()
+
+ def _build_decoder(self) -> nn.Sequential:
+ """Builds decoder."""
+ in_channels = self.hidden_dim * self.channels_multipliers[0]
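+        # Mid-block at the lowest resolution: residual -> attention -> residual.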
+ decoder = [
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ Attention(in_channels=in_channels),
+ Residual(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ ),
+ ]
+
+        out_channels_multipliers = self.channels_multipliers + (1,)
+ num_blocks = len(self.channels_multipliers)
+
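+        # Each block narrows the channels toward the next multiplier, and an
+        # Upsample doubles the spatial resolution; attention is added only at
+        # the first (deepest) block.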
+ for i in range(num_blocks):
+ in_channels = self.hidden_dim * self.channels_multipliers[i]
+ out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
+ decoder.append(
+ Residual(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout_rate=self.dropout_rate,
+ use_norm=True,
+ )
+ )
+            if i == 0:
+                decoder.append(Attention(in_channels=out_channels))
+ decoder.append(Upsample())
+
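+        # Output head: normalize, activate, and project to the output channels.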
+ decoder += [
+ Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
+ nn.Mish(),
+ nn.Conv2d(
+ in_channels=self.hidden_dim * out_channels_multipliers[-1],
+ out_channels=self.out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ ),
+ ]
+ return nn.Sequential(*decoder)
+
+ def forward(self, x: Tensor) -> Tensor:
+ """Decodes latent vector."""
+ return self.decoder(x)
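
A minimal usage sketch (not part of the diff), assuming that Downsample and
Upsample halve and double the spatial resolution respectively, and that the
decoder receives the channel multipliers in reverse order. The hyperparameters
are illustrative, not taken from the repository's configs:

    import torch

    from text_recognizer.networks.vqvae.pixelcnn import Decoder, Encoder

    encoder = Encoder(
        in_channels=1,
        hidden_dim=32,
        channels_multipliers=(1, 2, 4),
        dropout_rate=0.1,
    )
    decoder = Decoder(
        hidden_dim=32,
        channels_multipliers=(4, 2, 1),
        out_channels=1,
        dropout_rate=0.1,
    )

    x = torch.randn(1, 1, 64, 64)
    z = encoder(x)      # (1, 128, 8, 8): 32 * 4 channels, 64 / 2**3 resolution
    x_hat = decoder(z)  # (1, 1, 64, 64): back to the input shape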