diff options
author | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
---|---|---|
committer | Gustaf Rydholm <gustaf.rydholm@gmail.com> | 2021-03-20 18:09:06 +0100 |
commit | 7e8e54e84c63171e748bbf09516fd517e6821ace (patch) | |
tree | 996093f75a5d488dddf7ea1f159ed343a561ef89 /src/text_recognizer/networks/vqvae/encoder.py | |
parent | b0719d84138b6bbe5f04a4982dfca673aea1a368 (diff) |
Inital commit for refactoring to lightning
Diffstat (limited to 'src/text_recognizer/networks/vqvae/encoder.py')
-rw-r--r-- | src/text_recognizer/networks/vqvae/encoder.py | 147 |
1 files changed, 0 insertions, 147 deletions
diff --git a/src/text_recognizer/networks/vqvae/encoder.py b/src/text_recognizer/networks/vqvae/encoder.py deleted file mode 100644 index d3adac5..0000000 --- a/src/text_recognizer/networks/vqvae/encoder.py +++ /dev/null @@ -1,147 +0,0 @@ -"""CNN encoder for the VQ-VAE.""" -from typing import List, Optional, Tuple, Type - -import torch -from torch import nn -from torch import Tensor - -from text_recognizer.networks.util import activation_function -from text_recognizer.networks.vqvae.vector_quantizer import VectorQuantizer - - -class _ResidualBlock(nn.Module): - def __init__( - self, in_channels: int, out_channels: int, dropout: Optional[Type[nn.Module]], - ) -> None: - super().__init__() - self.block = [ - nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), - nn.ReLU(inplace=True), - nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False), - ] - - if dropout is not None: - self.block.append(dropout) - - self.block = nn.Sequential(*self.block) - - def forward(self, x: Tensor) -> Tensor: - """Apply the residual forward pass.""" - return x + self.block(x) - - -class Encoder(nn.Module): - """A CNN encoder network.""" - - def __init__( - self, - in_channels: int, - channels: List[int], - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - embedding_dim: int, - num_embeddings: int, - beta: float = 0.25, - activation: str = "leaky_relu", - dropout_rate: float = 0.0, - ) -> None: - super().__init__() - - if dropout_rate: - if activation == "selu": - dropout = nn.AlphaDropout(p=dropout_rate) - else: - dropout = nn.Dropout(p=dropout_rate) - else: - dropout = None - - self.embedding_dim = embedding_dim - self.num_embeddings = num_embeddings - self.beta = beta - activation = activation_function(activation) - - # Configure encoder. - self.encoder = self._build_encoder( - in_channels, - channels, - kernel_sizes, - strides, - num_residual_layers, - activation, - dropout, - ) - - # Configure Vector Quantizer. - self.vector_quantizer = VectorQuantizer( - self.num_embeddings, self.embedding_dim, self.beta - ) - - def _build_compression_block( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.ModuleList: - modules = nn.ModuleList([]) - configuration = zip(channels, kernel_sizes, strides) - for out_channels, kernel_size, stride in configuration: - modules.append( - nn.Sequential( - nn.Conv2d( - in_channels, out_channels, kernel_size, stride=stride, padding=1 - ), - activation, - ) - ) - - if dropout is not None: - modules.append(dropout) - - in_channels = out_channels - - return modules - - def _build_encoder( - self, - in_channels: int, - channels: int, - kernel_sizes: List[int], - strides: List[int], - num_residual_layers: int, - activation: Type[nn.Module], - dropout: Optional[Type[nn.Module]], - ) -> nn.Sequential: - encoder = nn.ModuleList([]) - - # compression module - encoder.extend( - self._build_compression_block( - in_channels, channels, kernel_sizes, strides, activation, dropout - ) - ) - - # Bottleneck module. - encoder.extend( - nn.ModuleList( - [ - _ResidualBlock(channels[-1], channels[-1], dropout) - for i in range(num_residual_layers) - ] - ) - ) - - encoder.append( - nn.Conv2d(channels[-1], self.embedding_dim, kernel_size=1, stride=1,) - ) - - return nn.Sequential(*encoder) - - def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]: - """Encodes input into a discrete representation.""" - z_e = self.encoder(x) - z_q, vq_loss = self.vector_quantizer(z_e) - return z_q, vq_loss |