-rw-r--r--  text_recognizer/networks/vqvae/pixelcnn.py            161
-rw-r--r--  training/conf/network/decoder/pixelcnn_decoder.yaml     5
-rw-r--r--  training/conf/network/encoder/pixelcnn_encoder.yaml     5
-rw-r--r--  training/conf/network/vqvae_pixelcnn.yaml                9
4 files changed, 0 insertions, 180 deletions
diff --git a/text_recognizer/networks/vqvae/pixelcnn.py b/text_recognizer/networks/vqvae/pixelcnn.py
deleted file mode 100644
index b9e6080..0000000
--- a/text_recognizer/networks/vqvae/pixelcnn.py
+++ /dev/null
@@ -1,161 +0,0 @@
-"""PixelCNN encoder and decoder.
-
-Same as in VQGAN, among others. Hopefully, better reconstructions...
-
-TODO: Add num of residual layers.
-"""
-from typing import Sequence
-
-from torch import nn
-from torch import Tensor
-
-from text_recognizer.networks.vqvae.attention import Attention
-from text_recognizer.networks.vqvae.norm import Normalize
-from text_recognizer.networks.vqvae.residual import Residual
-from text_recognizer.networks.vqvae.resize import Downsample, Upsample
-
-
-class Encoder(nn.Module):
- """PixelCNN encoder."""
-
- def __init__(
- self,
- in_channels: int,
- hidden_dim: int,
- channels_multipliers: Sequence[int],
- dropout_rate: float,
- ) -> None:
- super().__init__()
- self.in_channels = in_channels
- self.dropout_rate = dropout_rate
- self.hidden_dim = hidden_dim
- self.channels_multipliers = tuple(channels_multipliers)
- self.encoder = self._build_encoder()
-
- def _build_encoder(self) -> nn.Sequential:
- """Builds encoder."""
- encoder = [
- nn.Conv2d(
- in_channels=self.in_channels,
- out_channels=self.hidden_dim,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
- num_blocks = len(self.channels_multipliers)
- in_channels_multipliers = (1,) + self.channels_multipliers
- for i in range(num_blocks):
- in_channels = self.hidden_dim * in_channels_multipliers[i]
- out_channels = self.hidden_dim * self.channels_multipliers[i]
- encoder += [
- Residual(
- in_channels=in_channels,
- out_channels=out_channels,
- dropout_rate=self.dropout_rate,
- use_norm=True,
- ),
- ]
- if i == num_blocks - 1:
- encoder.append(Attention(in_channels=out_channels))
- encoder.append(Downsample())
-
- for _ in range(2):
- encoder += [
- Residual(
- in_channels=self.hidden_dim * self.channels_multipliers[-1],
- out_channels=self.hidden_dim * self.channels_multipliers[-1],
- dropout_rate=self.dropout_rate,
- use_norm=True,
- ),
- Attention(in_channels=self.hidden_dim * self.channels_multipliers[-1]),
- ]
-
- encoder += [
- Normalize(num_channels=self.hidden_dim * self.channels_multipliers[-1]),
- nn.Mish(),
- nn.Conv2d(
- in_channels=self.hidden_dim * self.channels_multipliers[-1],
- out_channels=self.hidden_dim * self.channels_multipliers[-1],
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
- return nn.Sequential(*encoder)
-
- def forward(self, x: Tensor) -> Tensor:
- """Encodes input to a latent representation."""
- return self.encoder(x)
-
-
-class Decoder(nn.Module):
- """PixelCNN decoder."""
-
- def __init__(
- self,
- hidden_dim: int,
- channels_multipliers: Sequence[int],
- out_channels: int,
- dropout_rate: float,
- ) -> None:
- super().__init__()
- self.hidden_dim = hidden_dim
- self.out_channels = out_channels
- self.channels_multipliers = tuple(channels_multipliers)
- self.dropout_rate = dropout_rate
- self.decoder = self._build_decoder()
-
- def _build_decoder(self) -> nn.Sequential:
- """Builds decoder."""
- in_channels = self.hidden_dim * self.channels_multipliers[0]
- decoder = [
- Residual(
- in_channels=in_channels,
- out_channels=in_channels,
- dropout_rate=self.dropout_rate,
- use_norm=True,
- ),
- Attention(in_channels=in_channels),
- Residual(
- in_channels=in_channels,
- out_channels=in_channels,
- dropout_rate=self.dropout_rate,
- use_norm=True,
- ),
- ]
-
- out_channels_multipliers = self.channels_multipliers + (1,)
- num_blocks = len(self.channels_multipliers)
-
- for i in range(num_blocks):
- in_channels = self.hidden_dim * self.channels_multipliers[i]
- out_channels = self.hidden_dim * out_channels_multipliers[i + 1]
- decoder.append(
- Residual(
- in_channels=in_channels,
- out_channels=out_channels,
- dropout_rate=self.dropout_rate,
- use_norm=True,
- )
- )
- if i == 0:
- decoder.append(Attention(in_channels=out_channels))
- decoder.append(Upsample())
-
- decoder += [
- Normalize(num_channels=self.hidden_dim * out_channels_multipliers[-1]),
- nn.Mish(),
- nn.Conv2d(
- in_channels=self.hidden_dim * out_channels_multipliers[-1],
- out_channels=self.out_channels,
- kernel_size=3,
- stride=1,
- padding=1,
- ),
- ]
- return nn.Sequential(*decoder)
-
- def forward(self, x: Tensor) -> Tensor:
- """Decodes latent vector."""
- return self.decoder(x)
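
The module above followed the usual VQGAN-style layout: the encoder adds one Residual stage followed by a Downsample per entry in channels_multipliers, and the decoder mirrors it with Upsample stages. Below is a minimal shape-check sketch, not part of the repository: it assumes the sibling modules (Residual, Attention, Normalize, Downsample, Upsample) from the parent commit are importable and that Downsample/Upsample halve/double the spatial resolution, as their names suggest. In the full VQVAE the vector quantizer sits between the two networks, so piping z straight into the decoder is only a shape check.

# Sketch only: build the deleted Encoder/Decoder with the values from the
# (also deleted) Hydra configs below and verify the round-trip shapes.
import torch

from text_recognizer.networks.vqvae.pixelcnn import Decoder, Encoder

encoder = Encoder(
    in_channels=1, hidden_dim=8, channels_multipliers=[1, 2, 8], dropout_rate=0.25
)
decoder = Decoder(
    out_channels=1, hidden_dim=8, channels_multipliers=[8, 2, 1], dropout_rate=0.25
)

x = torch.randn(1, 1, 64, 64)
z = encoder(x)      # expected: (1, 64, 8, 8), i.e. hidden_dim * 8 channels, spatial size / 2**3
x_hat = decoder(z)  # expected: (1, 1, 64, 64), upsampled back to the input resolution
print(z.shape, x_hat.shape)
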
diff --git a/training/conf/network/decoder/pixelcnn_decoder.yaml b/training/conf/network/decoder/pixelcnn_decoder.yaml
deleted file mode 100644
index cdddb7a..0000000
--- a/training/conf/network/decoder/pixelcnn_decoder.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-_target_: text_recognizer.networks.vqvae.pixelcnn.Decoder
-out_channels: 1
-hidden_dim: 8
-channels_multipliers: [8, 2, 1]
-dropout_rate: 0.25
diff --git a/training/conf/network/encoder/pixelcnn_encoder.yaml b/training/conf/network/encoder/pixelcnn_encoder.yaml
deleted file mode 100644
index f12957b..0000000
--- a/training/conf/network/encoder/pixelcnn_encoder.yaml
+++ /dev/null
@@ -1,5 +0,0 @@
-_target_: text_recognizer.networks.vqvae.pixelcnn.Encoder
-in_channels: 1
-hidden_dim: 8
-channels_multipliers: [1, 2, 8]
-dropout_rate: 0.25
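
Each of the two component configs above is a plain Hydra _target_ node, so it can also be instantiated on its own. A small sketch, assuming the yaml file still exists on disk (this commit deletes it) and that the text_recognizer package is importable:

# Hypothetical one-off instantiation of pixelcnn_encoder.yaml via Hydra.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("training/conf/network/encoder/pixelcnn_encoder.yaml")
encoder = instantiate(cfg)  # builds text_recognizer.networks.vqvae.pixelcnn.Encoder
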
diff --git a/training/conf/network/vqvae_pixelcnn.yaml b/training/conf/network/vqvae_pixelcnn.yaml
deleted file mode 100644
index cd850af..0000000
--- a/training/conf/network/vqvae_pixelcnn.yaml
+++ /dev/null
@@ -1,9 +0,0 @@
-defaults:
- - encoder: pixelcnn_encoder
- - decoder: pixelcnn_decoder
-
-_target_: text_recognizer.networks.vqvae.vqvae.VQVAE
-hidden_dim: 64
-embedding_dim: 32
-num_embeddings: 1024
-decay: 0.99
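
For reference, the removed vqvae_pixelcnn.yaml pulled the two component configs above in under its encoder and decoder keys via the defaults list and left the rest to Hydra. The sketch below shows the equivalent instantiation with the composed values inlined; it assumes Hydra >= 1.1 (recursive instantiation by default) and that VQVAE accepts the built encoder and decoder as keyword arguments, which the config layout implies but this diff does not show.

# Sketch only: the removed network config, written out as one OmegaConf node.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "text_recognizer.networks.vqvae.vqvae.VQVAE",
        "hidden_dim": 64,
        "embedding_dim": 32,
        "num_embeddings": 1024,
        "decay": 0.99,
        "encoder": {
            "_target_": "text_recognizer.networks.vqvae.pixelcnn.Encoder",
            "in_channels": 1,
            "hidden_dim": 8,
            "channels_multipliers": [1, 2, 8],
            "dropout_rate": 0.25,
        },
        "decoder": {
            "_target_": "text_recognizer.networks.vqvae.pixelcnn.Decoder",
            "out_channels": 1,
            "hidden_dim": 8,
            "channels_multipliers": [8, 2, 1],
            "dropout_rate": 0.25,
        },
    }
)
network = instantiate(cfg)  # recursively builds Encoder and Decoder, then VQVAE
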