diff options
author | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-24 22:14:17 +0100 |
---|---|---|
committer | aktersnurra <gustaf.rydholm@gmail.com> | 2021-01-24 22:14:17 +0100 |
commit | 4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch) | |
tree | 04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/text_recognizer/networks/vqvae | |
parent | d691b548cd0b6fc4ea184d64261f633789fee021 (diff) |
Many updates, cool stuff on the way.
Diffstat (limited to 'src/text_recognizer/networks/vqvae')
-rw-r--r-- | src/text_recognizer/networks/vqvae/__init__.py | 4 | ||||
-rw-r--r-- | src/text_recognizer/networks/vqvae/decoder.py | 133 | ||||
-rw-r--r-- | src/text_recognizer/networks/vqvae/encoder.py | 125 | ||||
-rw-r--r-- | src/text_recognizer/networks/vqvae/vector_quantizer.py | 2 | ||||
-rw-r--r-- | src/text_recognizer/networks/vqvae/vqvae.py | 74 |
5 files changed, 316 insertions, 22 deletions
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py index e1f05fa..763953c 100644 --- a/src/text_recognizer/networks/vqvae/__init__.py +++ b/src/text_recognizer/networks/vqvae/__init__.py @@ -1 +1,5 @@ """VQ-VAE module.""" +from .decoder import Decoder +from .encoder import Encoder +from .vector_quantizer import VectorQuantizer +from .vqvae import VQVAE diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py new file mode 100644 index 0000000..8847aba --- /dev/null +++ b/src/text_recognizer/networks/vqvae/decoder.py @@ -0,0 +1,133 @@ +"""CNN decoder for the VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.util import activation_function +from text_recognizer.networks.vqvae.encoder import _ResidualBlock + + +class Decoder(nn.Module): + """A CNN encoder network.""" + + def __init__( + self, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + upsampling: Optional[List[List[int]]] = None, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.upsampling = upsampling + + self.res_block = nn.ModuleList([]) + self.upsampling_block = nn.ModuleList([]) + + self.embedding_dim = embedding_dim + activation = activation_function(activation) + + # Configure encoder. + self.decoder = self._build_decoder( + channels, kernel_sizes, strides, num_residual_layers, activation, dropout, + ) + + def _build_decompression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for i, (out_channels, kernel_size, stride) in enumerate(configuration): + modules.append( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=1, + ), + activation, + ) + ) + + if i < len(self.upsampling): + modules.append(nn.Upsample(size=self.upsampling[i]),) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + modules.extend( + nn.Sequential( + nn.ConvTranspose2d( + in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1 + ), + nn.Tanh(), + ) + ) + + return modules + + def _build_decoder( + self, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + + self.res_block.append( + nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,) + ) + + # Bottleneck module. + self.res_block.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[0], channels[0], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + # Decompression module + self.upsampling_block.extend( + self._build_decompression_block( + channels[0], channels[1:], kernel_sizes, strides, activation, dropout + ) + ) + + self.res_block = nn.Sequential(*self.res_block) + self.upsampling_block = nn.Sequential(*self.upsampling_block) + + return nn.Sequential(self.res_block, self.upsampling_block) + + def forward(self, z_q: Tensor) -> Tensor: + """Reconstruct input from given codes.""" + x_reconstruction = self.decoder(z_q) + return x_reconstruction diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py index 60c4c43..d3adac5 100644 --- a/src/text_recognizer/networks/vqvae/encoder.py +++ b/src/text_recognizer/networks/vqvae/encoder.py @@ -1,6 +1,5 @@ """CNN encoder for the VQ-VAE.""" - -from typing import List, Optional, Type +from typing import List, Optional, Tuple, Type import torch from torch import nn @@ -12,16 +11,12 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer class _ResidualBlock(nn.Module): def __init__( - self, - in_channels: int, - out_channels: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], + self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], ) -> None: super().__init__() self.block = [ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - activation, + nn.ReLU(inplace=True), nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), ] @@ -42,23 +37,111 @@ class Encoder(nn.Module): self, in_channels: int, channels: List[int], + kernel_sizes: List[int], + strides: List[int], num_residual_layers: int, embedding_dim: int, num_embeddings: int, beta: float = 0.25, - activation: str = "elu", + activation: str = "leaky_relu", dropout_rate: float = 0.0, ) -> None: super().__init__() - pass - # if dropout_rate: - # if activation == "selu": - # dropout = nn.AlphaDropout(p=dropout_rate) - # else: - # dropout = nn.Dropout(p=dropout_rate) - # else: - # dropout = None - - def _build_encoder(self) -> nn.Sequential: - # TODO: Continue to implement encoder. - pass + + if dropout_rate: + if activation == "selu": + dropout = nn.AlphaDropout(p=dropout_rate) + else: + dropout = nn.Dropout(p=dropout_rate) + else: + dropout = None + + self.embedding_dim = embedding_dim + self.num_embeddings = num_embeddings + self.beta = beta + activation = activation_function(activation) + + # Configure encoder. + self.encoder = self._build_encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + activation, + dropout, + ) + + # Configure Vector Quantizer. + self.vector_quantizer = VectorQuantizer( + self.num_embeddings, self.embedding_dim, self.beta + ) + + def _build_compression_block( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.ModuleList: + modules = nn.ModuleList([]) + configuration = zip(channels, kernel_sizes, strides) + for out_channels, kernel_size, stride in configuration: + modules.append( + nn.Sequential( + nn.Conv2d( + in_channels, out_channels, kernel_size, stride=stride, padding=1 + ), + activation, + ) + ) + + if dropout is not None: + modules.append(dropout) + + in_channels = out_channels + + return modules + + def _build_encoder( + self, + in_channels: int, + channels: int, + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + activation: Type[nn.Module], + dropout: Optional[Type[nn.Module]], + ) -> nn.Sequential: + encoder = nn.ModuleList([]) + + # compression module + encoder.extend( + self._build_compression_block( + in_channels, channels, kernel_sizes, strides, activation, dropout + ) + ) + + # Bottleneck module. + encoder.extend( + nn.ModuleList( + [ + _ResidualBlock(channels[-1], channels[-1], dropout) + for i in range(num_residual_layers) + ] + ) + ) + + encoder.append( + nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) + ) + + return nn.Sequential(*encoder) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input into a discrete representation.""" + z_e = self.encoder(x) + z_q, vq_loss = self.vector_quantizer(z_e) + return z_q, vq_loss diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py index 25e5583..f92c7ee 100644 --- a/src/text_recognizer/networks/vqvae/vector_quantizer.py +++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py @@ -26,7 +26,7 @@ class VectorQuantizer(nn.Module): self.embedding = nn.Embedding(self.K, self.D) # Initialize the codebook. - self.embedding.weight.uniform_(-1 / self.K, 1 / self.K) + nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K) def discretization_bottleneck(self, latent: Tensor) -> Tensor: """Computes the code nearest to the latent representation. diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py new file mode 100644 index 0000000..50448b4 --- /dev/null +++ b/src/text_recognizer/networks/vqvae/vqvae.py @@ -0,0 +1,74 @@ +"""The VQ-VAE.""" + +from typing import List, Optional, Tuple, Type + +import torch +from torch import nn +from torch import Tensor + +from text_recognizer.networks.vqvae import Decoder, Encoder + + +class VQVAE(nn.Module): + """Vector Quantized Variational AutoEncoder.""" + + def __init__( + self, + in_channels: int, + channels: List[int], + kernel_sizes: List[int], + strides: List[int], + num_residual_layers: int, + embedding_dim: int, + num_embeddings: int, + upsampling: Optional[List[List[int]]] = None, + beta: float = 0.25, + activation: str = "leaky_relu", + dropout_rate: float = 0.0, + ) -> None: + super().__init__() + + # configure encoder. + self.encoder = Encoder( + in_channels, + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + num_embeddings, + beta, + activation, + dropout_rate, + ) + + # Configure decoder. + channels.reverse() + kernel_sizes.reverse() + strides.reverse() + self.decoder = Decoder( + channels, + kernel_sizes, + strides, + num_residual_layers, + embedding_dim, + upsampling, + activation, + dropout_rate, + ) + + def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Encodes input to a latent code.""" + return self.encoder(x) + + def decode(self, z_q: Tensor) -> Tensor: + """Reconstructs input from latent codes.""" + return self.decoder(z_q) + + def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """Compresses and decompresses input.""" + if len(x.shape) < 4: + x = x[(None,) * (4 - len(x.shape))] + z_q, vq_loss = self.encode(x) + x_reconstruction = self.decode(z_q) + return x_reconstruction, vq_loss |