summaryrefslogtreecommitdiff
path: root/src/text_recognizer/networks/vqvae
diff options
context:
space:
mode:
authoraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
committeraktersnurra <gustaf.rydholm@gmail.com>2021-01-24 22:14:17 +0100
commit4a54d7e690897dd6e6c719fb908fd371a44c2952 (patch)
tree04722ac94b9c3960baa5db7939d7ef01dbf535a6 /src/text_recognizer/networks/vqvae
parentd691b548cd0b6fc4ea184d64261f633789fee021 (diff)
Many updates, cool stuff on the way.
Diffstat (limited to 'src/text_recognizer/networks/vqvae')
-rw-r--r--src/text_recognizer/networks/vqvae/__init__.py4
-rw-r--r--src/text_recognizer/networks/vqvae/decoder.py133
-rw-r--r--src/text_recognizer/networks/vqvae/encoder.py125
-rw-r--r--src/text_recognizer/networks/vqvae/vector_quantizer.py2
-rw-r--r--src/text_recognizer/networks/vqvae/vqvae.py74
5 files changed, 316 insertions, 22 deletions
diff --git a/src/text_recognizer/networks/vqvae/__init__.py b/src/text_recognizer/networks/vqvae/__init__.py
index e1f05fa..763953c 100644
--- a/src/text_recognizer/networks/vqvae/__init__.py
+++ b/src/text_recognizer/networks/vqvae/__init__.py
@@ -1 +1,5 @@
"""VQ-VAE module."""
+from .decoder import Decoder
+from .encoder import Encoder
+from .vector_quantizer import VectorQuantizer
+from .vqvae import VQVAE
diff --git a/src/text_recognizer/networks/vqvae/decoder.py b/src/text_recognizer/networks/vqvae/decoder.py
new file mode 100644
index 0000000..8847aba
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/decoder.py
@@ -0,0 +1,133 @@
+"""CNN decoder for the VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.util import activation_function
+from text_recognizer.networks.vqvae.encoder import _ResidualBlock
+
+
+class Decoder(nn.Module):
+ """A CNN encoder network."""
+
+ def __init__(
+ self,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ upsampling: Optional[List[List[int]]] = None,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.upsampling = upsampling
+
+ self.res_block = nn.ModuleList([])
+ self.upsampling_block = nn.ModuleList([])
+
+ self.embedding_dim = embedding_dim
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.decoder = self._build_decoder(
+ channels, kernel_sizes, strides, num_residual_layers, activation, dropout,
+ )
+
+ def _build_decompression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for i, (out_channels, kernel_size, stride) in enumerate(configuration):
+ modules.append(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=stride,
+ padding=1,
+ ),
+ activation,
+ )
+ )
+
+ if i < len(self.upsampling):
+ modules.append(nn.Upsample(size=self.upsampling[i]),)
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ modules.extend(
+ nn.Sequential(
+ nn.ConvTranspose2d(
+ in_channels, 1, kernel_size=kernel_size, stride=stride, padding=1
+ ),
+ nn.Tanh(),
+ )
+ )
+
+ return modules
+
+ def _build_decoder(
+ self,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+
+ self.res_block.append(
+ nn.Conv2d(self.embedding_dim, channels[0], kernel_size=1, stride=1,)
+ )
+
+ # Bottleneck module.
+ self.res_block.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[0], channels[0], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ # Decompression module
+ self.upsampling_block.extend(
+ self._build_decompression_block(
+ channels[0], channels[1:], kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ self.res_block = nn.Sequential(*self.res_block)
+ self.upsampling_block = nn.Sequential(*self.upsampling_block)
+
+ return nn.Sequential(self.res_block, self.upsampling_block)
+
+ def forward(self, z_q: Tensor) -> Tensor:
+ """Reconstruct input from given codes."""
+ x_reconstruction = self.decoder(z_q)
+ return x_reconstruction
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py
index 60c4c43..d3adac5 100644
--- a/src/text_recognizer/networks/vqvae/encoder.py
+++ b/src/text_recognizer/networks/vqvae/encoder.py
@@ -1,6 +1,5 @@
"""CNN encoder for the VQ-VAE."""
-
-from typing import List, Optional, Type
+from typing import List, Optional, Tuple, Type
import torch
from torch import nn
@@ -12,16 +11,12 @@ from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer
class _ResidualBlock(nn.Module):
def __init__(
- self,
- in_channels: int,
- out_channels: int,
- activation: Type[nn.Module],
- dropout: Optional[Type[nn.Module]],
+ self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]],
) -> None:
super().__init__()
self.block = [
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
- activation,
+ nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False),
]
@@ -42,23 +37,111 @@ class Encoder(nn.Module):
self,
in_channels: int,
channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
num_residual_layers: int,
embedding_dim: int,
num_embeddings: int,
beta: float = 0.25,
- activation: str = "elu",
+ activation: str = "leaky_relu",
dropout_rate: float = 0.0,
) -> None:
super().__init__()
- pass
- # if dropout_rate:
- # if activation == "selu":
- # dropout = nn.AlphaDropout(p=dropout_rate)
- # else:
- # dropout = nn.Dropout(p=dropout_rate)
- # else:
- # dropout = None
-
- def _build_encoder(self) -> nn.Sequential:
- # TODO: Continue to implement encoder.
- pass
+
+ if dropout_rate:
+ if activation == "selu":
+ dropout = nn.AlphaDropout(p=dropout_rate)
+ else:
+ dropout = nn.Dropout(p=dropout_rate)
+ else:
+ dropout = None
+
+ self.embedding_dim = embedding_dim
+ self.num_embeddings = num_embeddings
+ self.beta = beta
+ activation = activation_function(activation)
+
+ # Configure encoder.
+ self.encoder = self._build_encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ activation,
+ dropout,
+ )
+
+ # Configure Vector Quantizer.
+ self.vector_quantizer = VectorQuantizer(
+ self.num_embeddings, self.embedding_dim, self.beta
+ )
+
+ def _build_compression_block(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.ModuleList:
+ modules = nn.ModuleList([])
+ configuration = zip(channels, kernel_sizes, strides)
+ for out_channels, kernel_size, stride in configuration:
+ modules.append(
+ nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=1
+ ),
+ activation,
+ )
+ )
+
+ if dropout is not None:
+ modules.append(dropout)
+
+ in_channels = out_channels
+
+ return modules
+
+ def _build_encoder(
+ self,
+ in_channels: int,
+ channels: int,
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ activation: Type[nn.Module],
+ dropout: Optional[Type[nn.Module]],
+ ) -> nn.Sequential:
+ encoder = nn.ModuleList([])
+
+ # compression module
+ encoder.extend(
+ self._build_compression_block(
+ in_channels, channels, kernel_sizes, strides, activation, dropout
+ )
+ )
+
+ # Bottleneck module.
+ encoder.extend(
+ nn.ModuleList(
+ [
+ _ResidualBlock(channels[-1], channels[-1], dropout)
+ for i in range(num_residual_layers)
+ ]
+ )
+ )
+
+ encoder.append(
+ nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,)
+ )
+
+ return nn.Sequential(*encoder)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input into a discrete representation."""
+ z_e = self.encoder(x)
+ z_q, vq_loss = self.vector_quantizer(z_e)
+ return z_q, vq_loss
diff --git a/src/text_recognizer/networks/vqvae/vector_quantizer.py b/src/text_recognizer/networks/vqvae/vector_quantizer.py
index 25e5583..f92c7ee 100644
--- a/src/text_recognizer/networks/vqvae/vector_quantizer.py
+++ b/src/text_recognizer/networks/vqvae/vector_quantizer.py
@@ -26,7 +26,7 @@ class VectorQuantizer(nn.Module):
self.embedding = nn.Embedding(self.K, self.D)
# Initialize the codebook.
- self.embedding.weight.uniform_(-1 / self.K, 1 / self.K)
+ nn.init.uniform_(self.embedding.weight, -1 / self.K, 1 / self.K)
def discretization_bottleneck(self, latent: Tensor) -> Tensor:
"""Computes the code nearest to the latent representation.
diff --git a/src/text_recognizer/networks/vqvae/vqvae.py b/src/text_recognizer/networks/vqvae/vqvae.py
new file mode 100644
index 0000000..50448b4
--- /dev/null
+++ b/src/text_recognizer/networks/vqvae/vqvae.py
@@ -0,0 +1,74 @@
+"""The VQ-VAE."""
+
+from typing import List, Optional, Tuple, Type
+
+import torch
+from torch import nn
+from torch import Tensor
+
+from text_recognizer.networks.vqvae import Decoder, Encoder
+
+
+class VQVAE(nn.Module):
+ """Vector Quantized Variational AutoEncoder."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ channels: List[int],
+ kernel_sizes: List[int],
+ strides: List[int],
+ num_residual_layers: int,
+ embedding_dim: int,
+ num_embeddings: int,
+ upsampling: Optional[List[List[int]]] = None,
+ beta: float = 0.25,
+ activation: str = "leaky_relu",
+ dropout_rate: float = 0.0,
+ ) -> None:
+ super().__init__()
+
+ # configure encoder.
+ self.encoder = Encoder(
+ in_channels,
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ num_embeddings,
+ beta,
+ activation,
+ dropout_rate,
+ )
+
+ # Configure decoder.
+ channels.reverse()
+ kernel_sizes.reverse()
+ strides.reverse()
+ self.decoder = Decoder(
+ channels,
+ kernel_sizes,
+ strides,
+ num_residual_layers,
+ embedding_dim,
+ upsampling,
+ activation,
+ dropout_rate,
+ )
+
+ def encode(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Encodes input to a latent code."""
+ return self.encoder(x)
+
+ def decode(self, z_q: Tensor) -> Tensor:
+ """Reconstructs input from latent codes."""
+ return self.decoder(z_q)
+
+ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
+ """Compresses and decompresses input."""
+ if len(x.shape) < 4:
+ x = x[(None,) * (4 - len(x.shape))]
+ z_q, vq_loss = self.encode(x)
+ x_reconstruction = self.decode(z_q)
+ return x_reconstruction, vq_loss